| import base64
|
| import logging
|
| import os
|
| import random
|
| import sys
|
|
|
| import comfy.model_management
|
| import folder_paths
|
| import numpy as np
|
| import torch
|
| import trimesh
|
| from PIL import Image
|
| from trimesh.exchange import gltf
|
|
|
| sys.path.append(os.path.dirname(__file__))
|
| from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
|
| from spar3d.system import SPAR3D
|
| from spar3d.utils import foreground_crop
|
|
|
# Value assigned to the CATEGORY attribute of every node class in this module.
SPAR3D_CATEGORY: str = "SPAR3D"

# Model identifier passed to SPAR3D.from_pretrained by SPAR3DLoader.load.
SPAR3D_MODEL_NAME: str = "stabilityai/spar3d"
|
|
|
|
|
class SPAR3DLoader:
    """Node that builds the SPAR3D model and places it on ComfyUI's torch device."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "load"
    RETURN_NAMES = ("spar3d_model",)
    RETURN_TYPES = ("SPAR3D_MODEL",)

    @classmethod
    def INPUT_TYPES(cls):
        """Declare a single required boolean widget, ``low_vram_mode``."""
        required_inputs = {"low_vram_mode": ("BOOLEAN", {"default": False})}
        return {"required": required_inputs}

    def load(self, low_vram_mode=False):
        """Instantiate SPAR3D from its pretrained config and weights.

        Args:
            low_vram_mode: Forwarded unchanged to ``SPAR3D.from_pretrained``.

        Returns:
            A one-element tuple holding the model, moved to the device reported
            by ``comfy.model_management.get_torch_device()`` and set to eval mode.
        """
        target_device = comfy.model_management.get_torch_device()
        spar3d = SPAR3D.from_pretrained(
            SPAR3D_MODEL_NAME,
            config_name="config.yaml",
            weight_name="model.safetensors",
            low_vram_mode=low_vram_mode,
        )
        spar3d.to(target_device)
        spar3d.eval()
        return (spar3d,)
|
|
|
|
|
class SPAR3DPreview:
    """Output node that hands meshes to the frontend as base64-encoded GLB data."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "preview"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    @classmethod
    def INPUT_TYPES(cls):
        """Declare one required input: a list-like ``MESH``."""
        return {"required": {"mesh": ("MESH",)}}

    def preview(self, mesh):
        """Encode every mesh in ``mesh`` as a base64 GLB string.

        Args:
            mesh: Iterable of geometries accepted by ``trimesh.Scene``.

        Returns:
            ``{"ui": {"glbs": [...]}}`` with one base64 string per input mesh,
            in input order.
        """

        def _encode(geometry):
            # Wrap in a Scene so gltf.export_glb receives a scene graph.
            payload = gltf.export_glb(trimesh.Scene(geometry), include_normals=True)
            return base64.b64encode(payload).decode("utf-8")

        encoded = [_encode(item) for item in mesh]
        return {"ui": {"glbs": encoded}}
|
|
|
|
|
class SPAR3DSampler:
    """Reconstruct a textured mesh and a point cloud from one image with SPAR3D."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "predict"
    RETURN_NAMES = ("mesh", "pointcloud")
    RETURN_TYPES = ("MESH", "POINTCLOUD")

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's inputs.

        The ``remesh`` option is only offered when at least one remeshing
        backend (``TRIANGLE_REMESH_AVAILABLE`` / ``QUAD_REMESH_AVAILABLE``) is
        available; its choices list only the available backends plus "none".
        """
        remesh_choices = ["none"]
        if TRIANGLE_REMESH_AVAILABLE:
            remesh_choices.append("triangle")
        if QUAD_REMESH_AVAILABLE:
            remesh_choices.append("quad")

        opt_dict = {
            "mask": ("MASK",),
            "pointcloud": ("POINTCLOUD",),
            "target_type": (["none", "vertex", "face"],),
            "target_count": (
                "INT",
                {"default": 1000, "min": 3, "max": 20000, "step": 1},
            ),
            "guidance_scale": (
                "FLOAT",
                {"default": 3.0, "min": 1.0, "max": 5.0, "step": 0.05},
            ),
            "seed": (
                "INT",
                {"default": 42, "min": 0, "max": 2**32 - 1, "step": 1},
            ),
        }
        if TRIANGLE_REMESH_AVAILABLE or QUAD_REMESH_AVAILABLE:
            opt_dict["remesh"] = (remesh_choices,)

        return {
            "required": {
                "model": ("SPAR3D_MODEL",),
                "image": ("IMAGE",),
                "foreground_ratio": (
                    "FLOAT",
                    {"default": 1.3, "min": 1.0, "max": 2.0, "step": 0.01},
                ),
                "texture_resolution": (
                    "INT",
                    {"default": 1024, "min": 512, "max": 2048, "step": 256},
                ),
            },
            "optional": opt_dict,
        }

    def predict(
        self,
        model,
        image,
        mask=None,
        foreground_ratio=1.3,
        texture_resolution=1024,
        pointcloud=None,
        remesh="none",
        target_type="none",
        target_count=1000,
        guidance_scale=3.0,
        seed=42,
    ):
        """Run SPAR3D on a single image.

        Args:
            model: Model from SPAR3DLoader; must expose ``cfg`` and ``run_image``.
            image: Image tensor indexed as ``image[0]`` with channels on axis 3.
                Assumed (B, H, W, C) floats in [0, 1] (ComfyUI IMAGE) — TODO
                confirm. Batch size must be 1.
            mask: Optional mask tensor; ``mask[0]`` becomes the alpha channel.
                Defaults to None because "mask" is declared optional in
                INPUT_TYPES, so the caller may omit it entirely.
            foreground_ratio: Forwarded to ``foreground_crop``.
            texture_resolution: Forwarded to ``run_image`` as ``bake_resolution``.
            pointcloud: Optional point cloud forwarded to ``run_image``.
            remesh: "none", "triangle" or "quad".
            target_type: "none" (no target), "vertex" or "face"; controls how
                ``target_count`` is turned into ``vertex_count``.
            target_count: Target vertex/face count when ``target_type`` != "none".
            guidance_scale: Written to ``model.cfg.guidance_scale`` (this
                mutates the shared model object).
            seed: Seeds ``random``, ``torch`` and ``numpy``.

        Returns:
            ``([mesh], pointcloud)`` where ``pointcloud`` is
            ``glob_dict["pointcloud"]`` flattened to a list of Python floats.

        Raises:
            ValueError: If the batch size is not 1 or the mesh has no vertices.
            ImportError: If the requested remesh backend is unavailable.
        """
        if image.shape[0] != 1:
            raise ValueError("Only one image can be processed at a time")

        # Validate the remesh backend up front, before any expensive work.
        if not TRIANGLE_REMESH_AVAILABLE and remesh == "triangle":
            raise ImportError("Triangle remeshing requires gpytoolbox to be installed")
        if not QUAD_REMESH_AVAILABLE and remesh == "quad":
            raise ImportError("Quad remeshing requires pynim to be installed")

        if target_type == "none":
            vertex_count = -1  # no target
        elif target_type == "face":
            # presumably uses V ~= F / 2 for a closed triangle mesh — confirm
            vertex_count = target_count // 2
        else:
            vertex_count = target_count

        pil_image = Image.fromarray(
            torch.clamp(torch.round(255.0 * image[0]), 0, 255)
            .type(torch.uint8)
            .cpu()
            .numpy()
        )

        if mask is not None:
            logging.info("Using Mask")
            mask_np = np.clip(255.0 * mask[0].detach().cpu().numpy(), 0, 255).astype(
                np.uint8
            )
            mask_pil = Image.fromarray(mask_np, mode="L")
            # putalpha requires matching sizes; adapt masks produced at another
            # resolution instead of failing with "images do not match".
            if mask_pil.size != pil_image.size:
                mask_pil = mask_pil.resize(pil_image.size, Image.BILINEAR)
            pil_image.putalpha(mask_pil)
        elif image.shape[3] != 4:
            logging.info("No mask or alpha channel detected, Converting to RGBA")
            pil_image = pil_image.convert("RGBA")

        pil_image = foreground_crop(pil_image, foreground_ratio)

        model.cfg.guidance_scale = guidance_scale
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)

        logging.debug("SPAR3D remesh mode: %s", remesh)
        with torch.no_grad():
            # NOTE(review): autocast is pinned to CUDA; on non-CUDA devices torch
            # disables it with a warning rather than failing — confirm intent.
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                mesh, glob_dict = model.run_image(
                    pil_image,
                    bake_resolution=texture_resolution,
                    pointcloud=pointcloud,
                    remesh=remesh,
                    vertex_count=vertex_count,
                )

        if mesh.vertices.shape[0] == 0:
            raise ValueError("No subject detected in the image")

        return (
            [mesh],
            glob_dict["pointcloud"].view(-1).detach().cpu().numpy().tolist(),
        )
|
|
|
|
|
class SPAR3DSave:
    """Output node that writes each mesh as a GLB file in ComfyUI's output folder."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "save"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "mesh": ("MESH",),
                "filename_prefix": ("STRING", {"default": "SPAR3D"}),
            }
        }

    def __init__(self):
        self.type = "output"

    def save(self, mesh, filename_prefix):
        """Export every mesh to ``<prefix>_<counter:05>_.glb``.

        Args:
            mesh: Iterable of geometries accepted by ``trimesh.Scene``.
            filename_prefix: Prefix resolved by
                ``folder_paths.get_save_image_path``. ``%batch_num%`` is
                replaced with the mesh's index.

        Returns:
            ``{"ui": {"glbs": [...]}}`` with one base64 GLB string per mesh.
        """
        output_dir = folder_paths.get_output_directory()
        # Resolve the folder and starting counter once, as ComfyUI's SaveImage
        # node does, then number the files sequentially instead of rescanning
        # the output directory for every mesh.
        full_output_folder, filename, counter, _, _ = folder_paths.get_save_image_path(
            filename_prefix, output_dir
        )

        glbs = []
        for idx, m in enumerate(mesh):
            glb_data = gltf.export_glb(trimesh.Scene(m), include_normals=True)
            logging.info("Generated GLB model with %d bytes", len(glb_data))

            batch_filename = filename.replace("%batch_num%", str(idx))
            # Trailing "_" keeps the name parseable by get_save_image_path's
            # counter scan on the next run.
            out_path = os.path.join(
                full_output_folder, f"{batch_filename}_{counter:05}_.glb"
            )
            with open(out_path, "wb") as f:
                f.write(glb_data)
            counter += 1

            glbs.append(base64.b64encode(glb_data).decode("utf-8"))
        return {"ui": {"glbs": glbs}}
|
|
|
|
|
class SPAR3DPointCloudLoader:
    """Node that loads a point cloud file into the flat POINTCLOUD list format."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "load_pointcloud"
    RETURN_TYPES = ("POINTCLOUD",)
    RETURN_NAMES = ("pointcloud",)

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "file": ("STRING", {"default": None}),
            }
        }

    def load_pointcloud(self, file):
        """Read ``file`` with trimesh and flatten it to ``[x, y, z, r, g, b, ...]``.

        Args:
            file: Path passed to ``trimesh.load``. None or "" yields no point cloud.

        Returns:
            A one-element tuple holding either None or a flat list of Python
            floats, six per point. RGB is ``vertex_colors[:, :3] / 255.0``, or
            1.0 when the visual reports no vertex colors.
        """
        if file is None or file == "":
            return (None,)

        geometry = trimesh.load(file)
        vertices = np.asarray(geometry.vertices, dtype=np.float64)

        vertex_colors = geometry.visual.vertex_colors
        if vertex_colors is not None:
            # Scale integer 0-255 channels to [0, 1]; alpha is dropped.
            colors = np.asarray(vertex_colors)[:, :3] / 255.0
        else:
            colors = np.ones((len(vertices), 3))

        # Vectorized replacement for the old per-point loop. That loop zipped
        # vertices with colors, i.e. truncated to the shorter one; keep that.
        count = min(len(vertices), len(colors))
        points = np.hstack(
            [vertices[:count, :3], np.asarray(colors[:count], dtype=np.float64)]
        )
        return (points.reshape(-1).tolist(),)
|
|
|
|
|
class SPAR3DPointCloudSaver:
    """Output node that writes a flat POINTCLOUD list to a colored PLY file."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "save_pointcloud"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "pointcloud": ("POINTCLOUD",),
                "filename_prefix": ("STRING", {"default": "SPAR3D"}),
            }
        }

    def save_pointcloud(self, pointcloud, filename_prefix):
        """Save ``pointcloud`` as ``<prefix>_<counter:05>_.ply``.

        Args:
            pointcloud: Flat sequence ``[x, y, z, r, g, b, ...]``. Colors are
                assumed to be floats nominally in [0, 1] — verify with producers.
            filename_prefix: Prefix resolved by ``folder_paths.get_save_image_path``.

        Returns:
            ``{"ui": {"text": [message]}}``. ComfyUI merges ``ui`` values as
            lists, so the message is wrapped in a one-element list; a bare
            string would be split into characters.

        Raises:
            ValueError: If the number of values is not a multiple of 6.
        """
        if pointcloud is None or len(pointcloud) == 0:
            return {"ui": {"text": ["No point cloud data to save"]}}

        points = np.asarray(pointcloud, dtype=np.float64).reshape(-1, 6)

        # Convert colors to uint8 explicitly. With float input, trimesh rejects
        # values above 255 and rescales arrays whose max is <= 1, so clip to
        # [0, 1] first, then round.
        colors = np.round(np.clip(points[:, 3:], 0.0, 1.0) * 255.0).astype(np.uint8)
        ply_data = trimesh.PointCloud(vertices=points[:, :3], colors=colors)

        output_dir = folder_paths.get_output_directory()
        full_output_folder, filename, counter, _, _ = folder_paths.get_save_image_path(
            filename_prefix, output_dir
        )
        # The "_" after the counter matters: get_save_image_path recovers the
        # counter via int(rest.split("_")[0]). With "...00001.ply" that parse
        # failed, the counter stayed at 1, and every save overwrote the last.
        out_path = os.path.join(full_output_folder, f"{filename}_{counter:05}_.ply")

        ply_data.export(out_path)

        return {"ui": {"text": [f"Saved point cloud to {out_path}"]}}
|
|
|
|
|
# Human-readable display name for each node id (same keys as NODE_CLASS_MAPPINGS).
NODE_DISPLAY_NAME_MAPPINGS = {
    "SPAR3DLoader": "SPAR3D Loader",
    "SPAR3DPreview": "SPAR3D Preview",
    "SPAR3DSampler": "SPAR3D Sampler",
    "SPAR3DSave": "SPAR3D Save",
    "SPAR3DPointCloudLoader": "SPAR3D Point Cloud Loader",
    "SPAR3DPointCloudSaver": "SPAR3D Point Cloud Saver",
}

# Node id -> implementing class. Presumably discovered by ComfyUI's custom-node
# loader through the names exported in __all__ — confirm.
NODE_CLASS_MAPPINGS = {
    "SPAR3DLoader": SPAR3DLoader,
    "SPAR3DPreview": SPAR3DPreview,
    "SPAR3DSampler": SPAR3DSampler,
    "SPAR3DSave": SPAR3DSave,
    "SPAR3DPointCloudLoader": SPAR3DPointCloudLoader,
    "SPAR3DPointCloudSaver": SPAR3DPointCloudSaver,
}

# Relative path to this extension's frontend files; assumption: served to the
# browser by ComfyUI — confirm.
WEB_DIRECTORY = "./comfyui"

# Public names exported by this module.
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
|
|
|