r3gm committed on
Commit
7bd100b
·
verified ·
1 Parent(s): 3bc5207

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -11
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import os
2
  import spaces
3
  import shutil
4
  import subprocess
@@ -11,7 +11,6 @@ import time
11
  import gc
12
  import uuid
13
  from tqdm import tqdm
14
-
15
  import cv2
16
  import numpy as np
17
  import torch
@@ -331,8 +330,14 @@ torch._dynamo.reset()
331
  quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
332
  torch._dynamo.reset()
333
 
334
- aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
335
- aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
 
 
 
 
 
 
336
 
337
  # pipe.vae.enable_slicing()
338
  # pipe.vae.enable_tiling()
@@ -419,14 +424,14 @@ def get_inference_duration(
419
  progress
420
  ):
421
  BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
422
- BASE_STEP_DURATION = 15
423
  width, height = resized_image.size
424
  factor = num_frames * width * height / BASE_FRAMES_HEIGHT_WIDTH
425
  step_duration = BASE_STEP_DURATION * factor ** 1.5
426
  gen_time = int(steps) * step_duration
427
 
428
  if guidance_scale > 1:
429
- gen_time = gen_time * 1.9
430
 
431
  frame_factor = frame_multiplier // FIXED_FPS
432
  if frame_factor > 1:
@@ -436,12 +441,12 @@ def get_inference_duration(
436
 
437
  total_time = 15 + gen_time
438
  if safe_mode:
439
- total_time = total_time * 1.20
440
 
441
  return total_time
442
 
443
 
444
- @spaces.GPU(duration=get_inference_duration)
445
  def run_inference(
446
  resized_image,
447
  processed_last_image,
@@ -633,7 +638,6 @@ CSS = """
633
 
634
  with gr.Blocks(delete_cache=(3600, 10800)) as demo:
635
  gr.Markdown(model_title())
636
- gr.Markdown("#### ℹ️ **A Note on Performance:** This version prioritizes a straightforward setup over maximum speed, so performance may vary.")
637
  gr.Markdown("Run Wan 2.2 in just 4-8 steps, fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU")
638
 
639
  with gr.Row():
@@ -649,8 +653,8 @@ with gr.Blocks(delete_cache=(3600, 10800)) as demo:
649
  )
650
  safe_mode_checkbox = gr.Checkbox(
651
  label="🛠️ Safe Mode",
652
- value=False,
653
- info="Requests 20% extra processing time to try to prevent unfinished tasks when the server is busy."
654
  )
655
  with gr.Accordion("Advanced Settings", open=False):
656
  last_image_component = gr.Image(type="pil", label="Last Image (Optional)", sources=["upload", "clipboard"])
 
1
+ import os; os.system('pip install --upgrade --no-deps spaces')
2
  import spaces
3
  import shutil
4
  import subprocess
 
11
  import gc
12
  import uuid
13
  from tqdm import tqdm
 
14
  import cv2
15
  import numpy as np
16
  import torch
 
330
  quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
331
  torch._dynamo.reset()
332
 
333
+ spaces.aoti_load(
334
+ module=pipe.transformer,
335
+ repo_id='cbensimon/WanTransformer3DModel-sm120-cu130-raa',
336
+ )
337
+ spaces.aoti_load(
338
+ module=pipe.transformer_2,
339
+ repo_id='cbensimon/WanTransformer3DModel-sm120-cu130-raa',
340
+ )
341
 
342
  # pipe.vae.enable_slicing()
343
  # pipe.vae.enable_tiling()
 
424
  progress
425
  ):
426
  BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
427
+ BASE_STEP_DURATION = 5.
428
  width, height = resized_image.size
429
  factor = num_frames * width * height / BASE_FRAMES_HEIGHT_WIDTH
430
  step_duration = BASE_STEP_DURATION * factor ** 1.5
431
  gen_time = int(steps) * step_duration
432
 
433
  if guidance_scale > 1:
434
+ gen_time = gen_time * 2.4
435
 
436
  frame_factor = frame_multiplier // FIXED_FPS
437
  if frame_factor > 1:
 
441
 
442
  total_time = 15 + gen_time
443
  if safe_mode:
444
+ total_time = total_time * 1.30
445
 
446
  return total_time
447
 
448
 
449
+ @spaces.GPU(duration=get_inference_duration, size='xlarge')
450
  def run_inference(
451
  resized_image,
452
  processed_last_image,
 
638
 
639
  with gr.Blocks(delete_cache=(3600, 10800)) as demo:
640
  gr.Markdown(model_title())
 
641
  gr.Markdown("Run Wan 2.2 in just 4-8 steps, fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU")
642
 
643
  with gr.Row():
 
653
  )
654
  safe_mode_checkbox = gr.Checkbox(
655
  label="🛠️ Safe Mode",
656
+ value=True,
657
+ info="Requests 30% extra processing time to try to prevent unfinished tasks when the server is busy."
658
  )
659
  with gr.Accordion("Advanced Settings", open=False):
660
  last_image_component = gr.Image(type="pil", label="Last Image (Optional)", sources=["upload", "clipboard"])