# TestModel / handler.py
# Hub page header: author senku02, commit 47ac06f ("Upload 2 files"), verified.
from typing import Dict, List, Any
import torch
import base64
import io
from PIL import Image
from tryon_core import TryOnEngine
from api_utils import prepare_image_for_processing, image_to_base64
class EndpointHandler:
    """Endpoint handler that serves IDM-VTON virtual try-on requests.

    The try-on engine is created and its models are loaded once in
    ``__init__``. Each call decodes two base64 images from the request, runs
    the engine on them and returns the results re-encoded as base64.
    """

    def __init__(self, path=""):
        """Create the try-on engine and load its models.

        Args:
            path: Path to the model files on the serving container. It is
                currently not used by this handler.
        """
        print("Initializing IDM-VTON Handler...")
        self.engine = TryOnEngine(
            load_mode="4bit", enable_cpu_offload=False, fixed_vae=True
        )
        # NOTE: to load the weights from the local path instead of the
        # engine's default model id, set ``self.engine.model_id = path``
        # before the load_models() call below.
        self.engine.load_models()
        self.engine.load_processing_models()
        print("Handler Initialized!")

    @staticmethod
    def _decode_base64(encoded: Any, field: str) -> bytes:
        """Return the raw bytes of the base64 request field *field*."""
        if isinstance(encoded, str) and encoded.startswith("data:"):
            # Data URI ("data:image/png;base64,<payload>"): keep the payload
            # only. Without this, b64decode silently drops the punctuation
            # but keeps the prefix letters, producing garbage bytes.
            encoded = encoded.partition(",")[2]
        if not isinstance(encoded, (str, bytes, bytearray)) or not encoded:
            raise ValueError(
                f"'{field}' is required and must be a non-empty base64 string"
            )
        try:
            raw = base64.b64decode(encoded)
        except ValueError as err:  # binascii.Error is a ValueError subclass
            raise ValueError(f"'{field}' is not valid base64 data") from err
        if not raw:
            raise ValueError(f"'{field}' decoded to an empty payload")
        return raw

    @staticmethod
    def _int_option(inputs: Dict[str, Any], key: str, default: int) -> int:
        """Return the optional integer request field *key*, or *default*."""
        value = inputs.get(key, default)
        try:
            return int(value)
        except (TypeError, ValueError) as err:
            raise ValueError(f"'{key}' must be an integer, got {value!r}") from err

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run one try-on request.

        Args:
            data: Request payload. Fields are read from ``data["inputs"]``
                when that key is present, otherwise from ``data`` itself.
                The payload is not modified. Recognised fields:

                * ``human_image`` / ``garment_image``: required, base64
                  encoded image data (a ``data:`` URI prefix is accepted).
                * ``garment_description``: optional, defaults to
                  ``"a photo of a garment"``.
                * ``category``: optional, defaults to ``"upper_body"``.
                * ``denoise_steps``: optional integer, defaults to 30.
                * ``seed``: optional integer, defaults to 42.

        Returns:
            A one-element list containing a dict with the keys
            ``generated_image`` and ``masked_image``, both encoded with
            ``image_to_base64``.

        Raises:
            ValueError: If the payload is not a dict, a required image is
                missing or not valid base64, or an integer option cannot be
                converted.
        """
        # 1. Extract inputs (read-only: the caller's payload is left intact)
        if not isinstance(data, dict):
            raise ValueError("request payload must be a JSON object")
        inputs = data.get("inputs", data)
        if not isinstance(inputs, dict):
            raise ValueError("'inputs' must be a JSON object")

        description = inputs.get("garment_description", "a photo of a garment")
        category = inputs.get("category", "upper_body")

        # 2. Validate and decode everything before any expensive work starts
        human_bytes = self._decode_base64(inputs.get("human_image"), "human_image")
        garment_bytes = self._decode_base64(
            inputs.get("garment_image"), "garment_image"
        )
        denoise_steps = self._int_option(inputs, "denoise_steps", 30)
        seed = self._int_option(inputs, "seed", 42)

        human_img = Image.open(io.BytesIO(human_bytes))
        garment_img = Image.open(io.BytesIO(garment_bytes))

        # 3. Process
        human_img = prepare_image_for_processing(human_img)
        garment_img = prepare_image_for_processing(garment_img)

        # 4. Generate
        generated_images, masked_image = self.engine.generate(
            human_img=human_img,
            garment_img=garment_img,
            garment_description=description,
            category=category,
            use_auto_mask=True,
            use_auto_crop=True,
            denoise_steps=denoise_steps,
            seed=seed,
            num_images=1,
        )

        # 5. Return result
        return [{
            "generated_image": image_to_base64(generated_images[0]),
            "masked_image": image_to_base64(masked_image),
        }]