from typing import Dict, List, Any

import base64
import io

import torch
from PIL import Image

from tryon_core import TryOnEngine
from api_utils import prepare_image_for_processing, image_to_base64


class EndpointHandler:
    def __init__(self, path=""):
        print("Initializing IDM-VTON Handler...")

        # Build the try-on engine with 4-bit quantized weights, no CPU
        # offload, and a fixed VAE.
        self.engine = TryOnEngine(load_mode="4bit", enable_cpu_offload=False, fixed_vae=True)

        # Load the main pipeline and the auxiliary processing models.
        self.engine.load_models()
        self.engine.load_processing_models()

        print("Handler Initialized!")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data (Dict[str, Any]):
                Includes the input data and the parameters for the inference.

        Returns:
            A single-element list with the base64-encoded try-on result and
            the masked human image.
        """
        # Requests may wrap the payload under "inputs"; fall back to the raw
        # dict when it is sent unwrapped.
        inputs = data.pop("inputs", data)
        human_img_b64 = inputs.get("human_image")
        garment_img_b64 = inputs.get("garment_image")
        description = inputs.get("garment_description", "a photo of a garment")
        category = inputs.get("category", "upper_body")

        # Decode the base64 payloads into PIL images.
        human_img = Image.open(io.BytesIO(base64.b64decode(human_img_b64)))
        garment_img = Image.open(io.BytesIO(base64.b64decode(garment_img_b64)))

        # Normalize both images for the try-on pipeline.
        human_img = prepare_image_for_processing(human_img)
        garment_img = prepare_image_for_processing(garment_img)

        # Run the try-on generation with automatic masking and cropping.
        generated_images, masked_image = self.engine.generate(
            human_img=human_img,
            garment_img=garment_img,
            garment_description=description,
            category=category,
            use_auto_mask=True,
            use_auto_crop=True,
            denoise_steps=30,
            seed=42,
            num_images=1,
        )

        # Encode the outputs back to base64 for the JSON response.
        return [{
            "generated_image": image_to_base64(generated_images[0]),
            "masked_image": image_to_base64(masked_image),
        }]
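

# --- Local usage sketch --------------------------------------------------
# Illustrative smoke test only, not part of the deployed endpoint: it assumes
# two sample images exist at the hypothetical paths "examples/human.jpg" and
# "examples/garment.jpg" and mirrors the JSON payload a client would send.
if __name__ == "__main__":
    def _encode_file(path: str) -> str:
        # Read an image file and return its base64 string, as the API expects.
        with open(path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()
    payload = {
        "inputs": {
            "human_image": _encode_file("examples/human.jpg"),      # hypothetical path
            "garment_image": _encode_file("examples/garment.jpg"),  # hypothetical path
            "garment_description": "a red cotton t-shirt",
            "category": "upper_body",
        }
    }
    result = handler(payload)
    # The response is a single-element list of base64-encoded images.
    print(sorted(result[0].keys()))  # ['generated_image', 'masked_image']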