AlekseyCalvin commited on
Commit
744516f
·
verified ·
1 Parent(s): cf1b750

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +561 -0
app.py ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ import gc
5
+ import shutil
6
+ import requests
7
+ import json
8
+ import struct
9
+ import numpy as np
10
+ import re
11
+ from pathlib import Path
12
+ from typing import Dict, Any, Optional
13
+ from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
14
+ from safetensors.torch import load_file, save_file
15
+ from tqdm import tqdm
16
+
17
+ # --- Memory Efficient Safetensors ---
18
+ class MemoryEfficientSafeOpen:
19
+ """
20
+ Reads safetensors metadata and tensors without mmap, keeping RAM usage low.
21
+ Essential for running on limited hardware.
22
+ """
23
+ def __init__(self, filename):
24
+ self.filename = filename
25
+ self.file = open(filename, "rb")
26
+ self.header, self.header_size = self._read_header()
27
+
28
+ def __enter__(self):
29
+ return self
30
+
31
+ def __exit__(self, exc_type, exc_val, exc_tb):
32
+ self.file.close()
33
+
34
+ def keys(self) -> list[str]:
35
+ return [k for k in self.header.keys() if k != "__metadata__"]
36
+
37
+ def metadata(self) -> Dict[str, str]:
38
+ return self.header.get("__metadata__", {})
39
+
40
+ def get_tensor(self, key):
41
+ if key not in self.header:
42
+ raise KeyError(f"Tensor '{key}' not found in the file")
43
+ metadata = self.header[key]
44
+ offset_start, offset_end = metadata["data_offsets"]
45
+ self.file.seek(self.header_size + 8 + offset_start)
46
+ tensor_bytes = self.file.read(offset_end - offset_start)
47
+ return self._deserialize_tensor(tensor_bytes, metadata)
48
+
49
+ def _read_header(self):
50
+ header_size = struct.unpack("<Q", self.file.read(8))[0]
51
+ header_json = self.file.read(header_size).decode("utf-8")
52
+ return json.loads(header_json), header_size
53
+
54
+ def _deserialize_tensor(self, tensor_bytes, metadata):
55
+ dtype_map = {
56
+ "F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16,
57
+ "I64": torch.int64, "I32": torch.int32, "I16": torch.int16, "I8": torch.int8,
58
+ "U8": torch.uint8, "BOOL": torch.bool
59
+ }
60
+ dtype = dtype_map[metadata["dtype"]]
61
+ shape = metadata["shape"]
62
+ return torch.frombuffer(tensor_bytes, dtype=torch.uint8).view(dtype).reshape(shape)
63
+
64
+ # --- Constants & Setup ---
65
+ TempDir = Path("./temp_tool")
66
+ os.makedirs(TempDir, exist_ok=True)
67
+ api = HfApi()
68
+
69
+ def cleanup_temp():
70
+ if TempDir.exists():
71
+ shutil.rmtree(TempDir)
72
+ os.makedirs(TempDir, exist_ok=True)
73
+ gc.collect()
74
+
75
+ def verify_safetensors(path):
76
+ """Checks if a file is a valid safetensors file."""
77
+ try:
78
+ with open(path, "rb") as f:
79
+ header_size_bytes = f.read(8)
80
+ if len(header_size_bytes) != 8: return False
81
+ header_size = struct.unpack("<Q", header_size_bytes)[0]
82
+ if header_size > os.path.getsize(path) or header_size <= 0:
83
+ return False
84
+ return True
85
+ except:
86
+ return False
87
+
88
+ def download_file(input_path, token, filename=None):
89
+ """Downloads a file from URL or HF Repo."""
90
+ local_path = TempDir / (filename if filename else "model.safetensors")
91
+
92
+ if input_path.startswith("http"):
93
+ print(f"Downloading from URL: {input_path}")
94
+ try:
95
+ response = requests.get(input_path, stream=True, timeout=30)
96
+ response.raise_for_status()
97
+ with open(local_path, 'wb') as f:
98
+ for chunk in response.iter_content(chunk_size=8192):
99
+ f.write(chunk)
100
+ except Exception as e:
101
+ raise ValueError(f"Failed to download URL. Check your link. Error: {e}")
102
+ else:
103
+ print(f"Downloading from Repo: {input_path}")
104
+ if not filename:
105
+ try:
106
+ files = list_repo_files(repo_id=input_path, token=token)
107
+ safetensors = [f for f in files if f.endswith(".safetensors")]
108
+ if safetensors:
109
+ filename = safetensors[0]
110
+ for f in safetensors:
111
+ if "adapter" in f: filename = f
112
+ else:
113
+ filename = "adapter_model.bin"
114
+ except Exception as e:
115
+ filename = "adapter_model.safetensors"
116
+
117
+ try:
118
+ hf_hub_download(repo_id=input_path, filename=filename, token=token, local_dir=TempDir, local_dir_use_symlinks=False)
119
+ downloaded_path = TempDir / filename
120
+ if downloaded_path != local_path:
121
+ if local_path.exists(): os.remove(local_path)
122
+ shutil.move(downloaded_path, local_path)
123
+ except Exception as e:
124
+ raise ValueError(f"Failed to download from HF Repo. Check ID/Token. Error: {e}")
125
+
126
+ if not verify_safetensors(local_path):
127
+ raise ValueError(f"Downloaded file is NOT a valid safetensors file. Check your URL/Repo. (File size: {os.path.getsize(local_path)} bytes)")
128
+
129
+ return local_path
130
+
131
+ def get_key_stem(key):
132
+ key = key.replace(".weight", "").replace(".bias", "")
133
+ key = key.replace(".lora_down", "").replace(".lora_up", "")
134
+ key = key.replace(".lora_A", "").replace(".lora_B", "")
135
+ key = key.replace(".alpha", "")
136
+
137
+ prefixes = [
138
+ "model.diffusion_model.", "diffusion_model.", "model.",
139
+ "transformer.", "text_encoder.", "lora_unet_", "lora_te_",
140
+ "base_model.model."
141
+ ]
142
+
143
+ changed = True
144
+ while changed:
145
+ changed = False
146
+ for p in prefixes:
147
+ if key.startswith(p):
148
+ key = key[len(p):]
149
+ changed = True
150
+ return key
151
+
152
+ # =================================================================================
153
+ # TAB 1: UNIVERSAL MERGE (Low-Precision Optimized)
154
+ # =================================================================================
155
+
156
+ def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
157
+ print(f"Loading LoRA from {lora_path} in {precision_dtype}...")
158
+ state_dict = load_file(lora_path, device="cpu")
159
+
160
+ pairs = {}
161
+ alphas = {}
162
+
163
+ for k, v in state_dict.items():
164
+ stem = get_key_stem(k)
165
+ if "alpha" in k:
166
+ alphas[stem] = v.item() if isinstance(v, torch.Tensor) else v
167
+ else:
168
+ if stem not in pairs:
169
+ pairs[stem] = {}
170
+
171
+ # Cast immediately to save RAM
172
+ tensor_low = v.to(dtype=precision_dtype)
173
+
174
+ if "lora_down" in k or "lora_A" in k:
175
+ pairs[stem]["down"] = tensor_low
176
+ pairs[stem]["rank"] = v.shape[0]
177
+ elif "lora_up" in k or "lora_B" in k:
178
+ pairs[stem]["up"] = tensor_low
179
+
180
+ for stem in pairs:
181
+ if stem in alphas:
182
+ pairs[stem]["alpha"] = alphas[stem]
183
+ else:
184
+ if "rank" in pairs[stem]:
185
+ pairs[stem]["alpha"] = float(pairs[stem]["rank"])
186
+ else:
187
+ pairs[stem]["alpha"] = 1.0
188
+
189
+ return pairs
190
+
191
+ def merge_shard_logic(base_path, lora_pairs, scale, output_path, precision_dtype=torch.bfloat16):
192
+ print(f"Loading base shard: {base_path}")
193
+ base_state = load_file(base_path, device="cpu")
194
+
195
+ lora_keys = set(lora_pairs.keys())
196
+ keys_to_process = list(base_state.keys())
197
+
198
+ for k in keys_to_process:
199
+ v = base_state[k]
200
+ base_stem = get_key_stem(k)
201
+ match = None
202
+
203
+ # 1. Exact Match
204
+ if base_stem in lora_keys:
205
+ match = lora_pairs[base_stem]
206
+ else:
207
+ # 2. Heuristic Match
208
+ if "to_q" in base_stem:
209
+ qkv_stem = base_stem.replace("to_q", "qkv")
210
+ if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
211
+ elif "to_k" in base_stem:
212
+ qkv_stem = base_stem.replace("to_k", "qkv")
213
+ if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
214
+ elif "to_v" in base_stem:
215
+ qkv_stem = base_stem.replace("to_v", "qkv")
216
+ if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
217
+
218
+ if match and "down" in match and "up" in match:
219
+ down = match["down"]
220
+ up = match["up"]
221
+ alpha = match["alpha"]
222
+ rank = match["rank"]
223
+
224
+ scaling = scale * (alpha / rank)
225
+
226
+ # Handle Conv 1x1 squeeze
227
+ if len(v.shape) == 4 and len(down.shape) == 2:
228
+ down = down.unsqueeze(-1).unsqueeze(-1)
229
+ up = up.unsqueeze(-1).unsqueeze(-1)
230
+
231
+ try:
232
+ if len(up.shape) == 4:
233
+ delta = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1)
234
+ else:
235
+ delta = up @ down
236
+ except:
237
+ delta = up.T @ down
238
+
239
+ delta = delta * scaling
240
+
241
+ valid_delta = True
242
+
243
+ # --- Dynamic Reshaping / Slicing ---
244
+ if delta.shape == v.shape:
245
+ pass
246
+ elif delta.shape[0] == v.shape[0] * 3:
247
+ chunk_size = v.shape[0]
248
+ if "to_q" in k:
249
+ delta = delta[0:chunk_size, ...]
250
+ elif "to_k" in k:
251
+ delta = delta[chunk_size:2*chunk_size, ...]
252
+ elif "to_v" in k:
253
+ delta = delta[2*chunk_size:, ...]
254
+ else:
255
+ valid_delta = False
256
+ elif delta.numel() == v.numel():
257
+ delta = delta.reshape(v.shape)
258
+ else:
259
+ # print(f"Skipping {k}: Mismatch. Base: {v.shape}, Delta: {delta.shape}")
260
+ valid_delta = False
261
+
262
+ if valid_delta:
263
+ # Optimized In-Place Addition (Zero Copy)
264
+ if v.dtype != delta.dtype:
265
+ delta = delta.to(v.dtype)
266
+
267
+ v.add_(delta)
268
+ del delta
269
+
270
+ if len(keys_to_process) > 100 and keys_to_process.index(k) % 50 == 0:
271
+ gc.collect()
272
+
273
+ save_file(base_state, output_path)
274
+ return True
275
+
276
+ # NOTE: Arguments must match exactly with the inputs=[] list in click()
277
+ def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision, output_repo, structure_repo, private, progress=gr.Progress()):
278
+ cleanup_temp()
279
+ login(hf_token)
280
+
281
+ # Determine Dtype
282
+ if precision == "bf16":
283
+ dtype = torch.bfloat16
284
+ elif precision == "fp16":
285
+ dtype = torch.float16
286
+ else:
287
+ dtype = torch.float32
288
+
289
+ print(f"Selected Precision: {dtype}")
290
+
291
+ try:
292
+ api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
293
+ except Exception as e:
294
+ return f"Error creating repo: {e}"
295
+
296
+ if structure_repo:
297
+ print("Cloning structure...")
298
+ try:
299
+ files = list_repo_files(repo_id=structure_repo, token=hf_token)
300
+ for f in files:
301
+ if not f.endswith(".safetensors") and not f.endswith(".bin"):
302
+ try:
303
+ path = hf_hub_download(repo_id=structure_repo, filename=f, token=hf_token)
304
+ api.upload_file(path_or_fileobj=path, path_in_repo=f, repo_id=output_repo, token=hf_token)
305
+ except: pass
306
+ except Exception as e:
307
+ print(f"Structure clone warning: {e}")
308
+
309
+ try:
310
+ progress(0.1, desc="Downloading LoRA...")
311
+ lora_path = download_file(lora_input, hf_token)
312
+ lora_pairs = load_lora_to_memory(lora_path, precision_dtype=dtype)
313
+ except Exception as e:
314
+ return f"CRITICAL ERROR: {str(e)}"
315
+
316
+ files = list_repo_files(repo_id=base_repo, token=hf_token)
317
+ shards = [f for f in files if f.endswith(".safetensors")]
318
+ if base_subfolder:
319
+ shards = [f for f in shards if f.startswith(base_subfolder)]
320
+
321
+ if not shards: return "Error: No safetensors found in base."
322
+
323
+ for i, shard in enumerate(shards):
324
+ progress(0.2 + (0.8 * i/len(shards)), desc=f"Merging {shard}")
325
+ local_shard = hf_hub_download(repo_id=base_repo, filename=shard, token=hf_token, local_dir=TempDir)
326
+ merged_path = TempDir / "merged.safetensors"
327
+
328
+ # Merge
329
+ merge_shard_logic(local_shard, lora_pairs, scale, merged_path, precision_dtype=dtype)
330
+
331
+ # Upload
332
+ api.upload_file(path_or_fileobj=merged_path, path_in_repo=shard, repo_id=output_repo, token=hf_token)
333
+
334
+ os.remove(local_shard)
335
+ if merged_path.exists(): os.remove(merged_path)
336
+ gc.collect()
337
+
338
+ return f"Done! Model at https://huggingface.co/{output_repo}"
339
+
340
+ # =================================================================================
341
+ # TAB 2: EXTRACT LORA
342
+ # =================================================================================
343
+
344
+ def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
345
+ org = MemoryEfficientSafeOpen(model_org)
346
+ tuned = MemoryEfficientSafeOpen(model_tuned)
347
+ lora_sd = {}
348
+
349
+ print("Calculating diffs and running SVD (Layer-wise)...")
350
+ keys = list(org.keys())
351
+
352
+ for key in tqdm(keys):
353
+ if key not in tuned.keys(): continue
354
+ mat_org = org.get_tensor(key).float()
355
+ mat_tuned = tuned.get_tensor(key).float()
356
+
357
+ diff = mat_tuned - mat_org
358
+ if torch.max(torch.abs(diff)) < 1e-4: continue
359
+
360
+ out_dim, in_dim = diff.shape[:2]
361
+ r = min(rank, in_dim, out_dim)
362
+ is_conv = len(diff.shape) == 4
363
+ if is_conv: diff = diff.flatten(start_dim=1)
364
+
365
+ try:
366
+ U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
367
+ U = U[:, :r]
368
+ S = S[:r]
369
+ U = U @ torch.diag(S)
370
+ Vh = Vh[:r, :]
371
+
372
+ dist = torch.cat([U.flatten(), Vh.flatten()])
373
+ hi_val = torch.quantile(dist, clamp)
374
+ U = U.clamp(-hi_val, hi_val)
375
+ Vh = Vh.clamp(-hi_val, hi_val)
376
+
377
+ if is_conv:
378
+ U = U.reshape(out_dim, r, 1, 1)
379
+ Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3])
380
+ else:
381
+ U = U.reshape(out_dim, r)
382
+ Vh = Vh.reshape(r, in_dim)
383
+
384
+ stem = key.replace(".weight", "")
385
+ lora_sd[f"{stem}.lora_up.weight"] = U
386
+ lora_sd[f"{stem}.lora_down.weight"] = Vh
387
+ lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
388
+ except Exception as e:
389
+ print(f"SVD failed for {key}: {e}")
390
+
391
+ out_path = TempDir / "extracted_lora.safetensors"
392
+ save_file(lora_sd, out_path)
393
+ return str(out_path)
394
+
395
+ def task_extract(hf_token, org_repo, tuned_repo, rank, output_repo):
396
+ cleanup_temp()
397
+ login(hf_token)
398
+ print("Downloading models...")
399
+ try:
400
+ p1 = download_file(org_repo, hf_token, "org.safetensors")
401
+ p2 = download_file(tuned_repo, hf_token, "tuned.safetensors")
402
+ out = extract_lora_layer_by_layer(p1, p2, int(rank), 0.99)
403
+ api.create_repo(repo_id=output_repo, exist_ok=True, token=hf_token)
404
+ api.upload_file(path_or_fileobj=out, path_in_repo="extracted_lora.safetensors", repo_id=output_repo, token=hf_token)
405
+ return "Extraction Done."
406
+ except Exception as e:
407
+ return f"Error: {e}"
408
+
409
+ # =================================================================================
410
+ # TAB 3: MERGE ADAPTERS (EMA)
411
+ # =================================================================================
412
+
413
+ def task_merge_adapters(hf_token, lora_urls, beta, output_repo):
414
+ cleanup_temp()
415
+ login(hf_token)
416
+ urls = [u.strip() for u in lora_urls.split(",") if u.strip()]
417
+ paths = []
418
+ try:
419
+ for i, url in enumerate(urls):
420
+ paths.append(download_file(url, hf_token, f"adapter_{i}.safetensors"))
421
+ except Exception as e:
422
+ return f"Download Error: {e}"
423
+
424
+ if not paths: return "No models found"
425
+
426
+ base_sd = load_file(paths[0], device="cpu")
427
+ for k in base_sd:
428
+ if base_sd[k].dtype.is_floating_point: base_sd[k] = base_sd[k].float()
429
+
430
+ for i, path in enumerate(paths[1:]):
431
+ print(f"Merging {path}")
432
+ curr = load_file(path, device="cpu")
433
+ for k in base_sd:
434
+ if k in curr and "alpha" not in k:
435
+ base_sd[k] = base_sd[k] * beta + curr[k].float() * (1 - beta)
436
+
437
+ out = TempDir / "merged_adapters.safetensors"
438
+ save_file(base_sd, out)
439
+ api.create_repo(repo_id=output_repo, exist_ok=True, token=hf_token)
440
+ api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=output_repo, token=hf_token)
441
+ return "Done"
442
+
443
+ # =================================================================================
444
+ # TAB 4: RESIZE
445
+ # =================================================================================
446
+
447
+ def task_resize(hf_token, lora_input, new_rank, output_repo):
448
+ cleanup_temp()
449
+ login(hf_token)
450
+ try:
451
+ path = download_file(lora_input, hf_token)
452
+ except Exception as e:
453
+ return f"Download Error: {e}"
454
+
455
+ state = load_file(path, device="cpu")
456
+ new_state = {}
457
+ print("Resizing...")
458
+
459
+ groups = {}
460
+ for k in state:
461
+ stem = get_key_stem(k)
462
+ stem_simple = k.split(".lora_")[0]
463
+ if stem_simple not in groups: groups[stem_simple] = {}
464
+ if "lora_down" in k or "lora_A" in k: groups[stem_simple]["down"] = state[k]
465
+ if "lora_up" in k or "lora_B" in k: groups[stem_simple]["up"] = state[k]
466
+
467
+ for stem, g in tqdm(groups.items()):
468
+ if "down" in g and "up" in g:
469
+ down, up = g["down"].float(), g["up"].float()
470
+ if len(down.shape) == 4:
471
+ merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3])
472
+ flat = merged.flatten(1)
473
+ else:
474
+ merged = up @ down
475
+ flat = merged
476
+
477
+ U, S, Vh = torch.linalg.svd(flat, full_matrices=False)
478
+ U = U[:, :new_rank]
479
+ S = S[:new_rank]
480
+ U = U @ torch.diag(S)
481
+ Vh = Vh[:new_rank, :]
482
+
483
+ if len(down.shape) == 4:
484
+ U = U.reshape(up.shape[0], new_rank, 1, 1)
485
+ Vh = Vh.reshape(new_rank, down.shape[1], down.shape[2], down.shape[3])
486
+
487
+ new_state[f"{stem}.lora_down.weight"] = Vh
488
+ new_state[f"{stem}.lora_up.weight"] = U
489
+ new_state[f"{stem}.alpha"] = torch.tensor(new_rank).float()
490
+
491
+ out = TempDir / "resized.safetensors"
492
+ save_file(new_state, out)
493
+ api.create_repo(repo_id=output_repo, exist_ok=True, token=hf_token)
494
+ api.upload_file(path_or_fileobj=out, path_in_repo="resized.safetensors", repo_id=output_repo, token=hf_token)
495
+ return "Done"
496
+
497
+ # =================================================================================
498
+ # UI Construction
499
+ # =================================================================================
500
+
501
+ css = ".container { max-width: 900px; margin: auto; }"
502
+
503
+ with gr.Blocks() as demo:
504
+ gr.Markdown("# 🧰 SOONmerge® LoRA Toolkit")
505
+
506
+ with gr.Tabs():
507
+ with gr.Tab("Merge (Z-Image Fix)"):
508
+ t1_token = gr.Textbox(label="Token", type="password")
509
+ t1_base = gr.Textbox(label="Base Repo", value="ostris/Z-Image-De-Turbo")
510
+ t1_sub = gr.Textbox(label="Subfolder", value="transformer")
511
+ t1_lora = gr.Textbox(label="LoRA")
512
+
513
+ with gr.Row():
514
+ t1_scale = gr.Slider(label="Scale", value=1.0, minimum=-1, maximum=2)
515
+ t1_prec = gr.Radio(["bf16", "fp16", "float32"], label="Precision", value="bf16")
516
+
517
+ t1_out = gr.Textbox(label="Output")
518
+ t1_struct = gr.Textbox(label="Structure Repo", value="Tongyi-MAI/Z-Image-Turbo")
519
+ # Explicitly defined checkbox to ensure correct arg count
520
+ t1_private = gr.Checkbox(label="Private Repo", value=True)
521
+
522
+ t1_btn = gr.Button("Merge")
523
+ t1_res = gr.Textbox(label="Result")
524
+
525
+ # Corrected argument count: exactly 9 inputs + 1 output
526
+ t1_btn.click(
527
+ task_merge,
528
+ inputs=[t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_out, t1_struct, t1_private],
529
+ outputs=t1_res
530
+ )
531
+
532
+ with gr.Tab("Extract"):
533
+ t2_token = gr.Textbox(label="Token", type="password")
534
+ t2_org = gr.Textbox(label="Original")
535
+ t2_tun = gr.Textbox(label="Tuned")
536
+ t2_rank = gr.Number(label="Rank", value=32)
537
+ t2_out = gr.Textbox(label="Output")
538
+ t2_btn = gr.Button("Extract")
539
+ t2_res = gr.Textbox(label="Result")
540
+ t2_btn.click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], t2_res)
541
+
542
+ with gr.Tab("Merge Adapters"):
543
+ t3_token = gr.Textbox(label="Token", type="password")
544
+ t3_urls = gr.Textbox(label="URLs (comma sep)")
545
+ t3_beta = gr.Slider(label="Beta", value=0.9)
546
+ t3_out = gr.Textbox(label="Output")
547
+ t3_btn = gr.Button("Merge")
548
+ t3_res = gr.Textbox(label="Result")
549
+ t3_btn.click(task_merge_adapters, [t3_token, t3_urls, t3_beta, t3_out], t3_res)
550
+
551
+ with gr.Tab("Resize"):
552
+ t4_token = gr.Textbox(label="Token", type="password")
553
+ t4_in = gr.Textbox(label="LoRA")
554
+ t4_rank = gr.Number(label="Rank", value=8)
555
+ t4_out = gr.Textbox(label="Output")
556
+ t4_btn = gr.Button("Resize")
557
+ t4_res = gr.Textbox(label="Result")
558
+ t4_btn.click(task_resize, [t4_token, t4_in, t4_rank, t4_out], t4_res)
559
+
560
+ if __name__ == "__main__":
561
+ demo.queue().launch(css=css, ssr_mode=False)