ahmetyaylalioglu commited on
Commit
0fef7b3
·
verified ·
1 Parent(s): d47b21c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -12
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from diffusers import DDPMPipeline
2
  import torch
3
  import numpy as np
@@ -7,33 +8,31 @@ import torchvision.transforms as transforms
7
  from PIL import Image
8
  import logging
9
 
10
- #–– Set up logging
11
  logging.basicConfig(level=logging.INFO)
12
 
13
- #–– Check for accelerate
14
  try:
15
  from accelerate import Accelerator
16
  logging.info("Accelerate library found.")
17
  except ImportError:
18
- logging.warning("Accelerate library not found; for large models, 'accelerate' is recommended.")
19
 
20
- #–– Load the DDPM pipeline
21
  MODEL_ID = "ahmetyaylalioglu/textile_diffusion_ddpm"
22
  logging.info(f"Loading model from {MODEL_ID}...")
23
  pipeline = DDPMPipeline.from_pretrained(MODEL_ID)
24
 
25
- #–– Device setup (ZeroGPU → CPU)
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
  logging.info(f"Using device: {device}")
28
  pipeline.to(device)
29
- pipeline.unet.to(torch.float32) # ensure CPU‐friendly dtype
30
  pipeline.unet.eval()
31
 
32
- #–– Image generation function
33
  def generate_images(seed, num_images):
34
  try:
35
  seed = int(seed)
36
- num_images = min(int(num_images), 16) # cap to 16 to avoid OOM
37
  logging.info(f"Generating {num_images} images with seed {seed}")
38
 
39
  torch.manual_seed(seed)
@@ -43,10 +42,9 @@ def generate_images(seed, num_images):
43
  grid = make_grid([transforms.ToTensor()(img) for img in imgs], nrow=min(4, num_images))
44
  return transforms.ToPILImage()(grid)
45
  except Exception as e:
46
- logging.error(f"Error in generate_images: {e}")
47
  return None
48
 
49
- #–– Gradio UI
50
  interface = gr.Interface(
51
  fn=generate_images,
52
  inputs=[
@@ -55,8 +53,8 @@ interface = gr.Interface(
55
  ],
56
  outputs="image",
57
  title="Textile Diffusion (DDPM)",
58
- description="Generate textile patterns via a DDPM model from Hugging Face."
59
  )
60
 
61
  if __name__ == "__main__":
62
- interface.launch(share=True)
 
1
+ # app.py
2
  from diffusers import DDPMPipeline
3
  import torch
4
  import numpy as np
 
8
  from PIL import Image
9
  import logging
10
 
 
11
  logging.basicConfig(level=logging.INFO)
12
 
13
+ # Optional: warn if 'accelerate' is missing
14
  try:
15
  from accelerate import Accelerator
16
  logging.info("Accelerate library found.")
17
  except ImportError:
18
+ logging.warning("Accelerate library not found; it's recommended for large models.")
19
 
20
+ # Load DDPM pipeline
21
  MODEL_ID = "ahmetyaylalioglu/textile_diffusion_ddpm"
22
  logging.info(f"Loading model from {MODEL_ID}...")
23
  pipeline = DDPMPipeline.from_pretrained(MODEL_ID)
24
 
25
+ # ZeroGPU → CPU fallback
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
  logging.info(f"Using device: {device}")
28
  pipeline.to(device)
29
+ pipeline.unet.to(torch.float32) # ensure CPU dtype
30
  pipeline.unet.eval()
31
 
 
32
  def generate_images(seed, num_images):
33
  try:
34
  seed = int(seed)
35
+ num_images = min(int(num_images), 16)
36
  logging.info(f"Generating {num_images} images with seed {seed}")
37
 
38
  torch.manual_seed(seed)
 
42
  grid = make_grid([transforms.ToTensor()(img) for img in imgs], nrow=min(4, num_images))
43
  return transforms.ToPILImage()(grid)
44
  except Exception as e:
45
+ logging.error(f"generate_images error: {e}")
46
  return None
47
 
 
48
  interface = gr.Interface(
49
  fn=generate_images,
50
  inputs=[
 
53
  ],
54
  outputs="image",
55
  title="Textile Diffusion (DDPM)",
56
+ description="Generate textile patterns via a DDPM model."
57
  )
58
 
59
  if __name__ == "__main__":
60
+ interface.launch(share=True)