ankandrew
committed on
Commit
·
c5c055b
1
Parent(s):
4de9907
Use flash_attention_2 if available
Browse files
app.py
CHANGED
|
@@ -3,7 +3,7 @@ import gradio as gr
|
|
| 3 |
import spaces
|
| 4 |
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
|
| 5 |
from qwen_vl_utils import process_vision_info
|
| 6 |
-
|
| 7 |
|
| 8 |
subprocess.run(
|
| 9 |
"pip install flash-attn --no-build-isolation",
|
|
@@ -29,7 +29,8 @@ def run_inference(model_key, input_type, text, image, video, fps):
|
|
| 29 |
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 30 |
model_id,
|
| 31 |
torch_dtype="auto",
|
| 32 |
-
device_map="auto"
|
|
|
|
| 33 |
)
|
| 34 |
processor = AutoProcessor.from_pretrained(model_id)
|
| 35 |
|
|
|
|
| 3 |
import spaces
|
| 4 |
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
|
| 5 |
from qwen_vl_utils import process_vision_info
|
| 6 |
+
from transformers.utils import is_flash_attn_2_available
|
| 7 |
|
| 8 |
subprocess.run(
|
| 9 |
"pip install flash-attn --no-build-isolation",
|
|
|
|
| 29 |
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 30 |
model_id,
|
| 31 |
torch_dtype="auto",
|
| 32 |
+
device_map="auto",
|
| 33 |
+
attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
|
| 34 |
)
|
| 35 |
processor = AutoProcessor.from_pretrained(model_id)
|
| 36 |
|