Trouter-Library commited on
Commit
b4d6913
·
verified ·
1 Parent(s): 6ea7c60

Create inference/generate_shards.py

Browse files
Files changed (1) hide show
  1. inference/generate_shards.py +327 -0
inference/generate_shards.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Shard Generator for Helion-OSC
3
+ Creates placeholder or actual safetensors shard files
4
+
5
+ This script helps you:
6
+ 1. Generate placeholder shards for testing
7
+ 2. Split a large model into 116 shards
8
+ 3. Verify shard integrity
9
+ """
10
+
11
+ import torch
12
+ import json
13
+ import os
14
+ from pathlib import Path
15
+ from typing import Dict, List, Optional
16
+ import logging
17
+ from tqdm import tqdm
18
+ from safetensors.torch import save_file, load_file
19
+ import numpy as np
20
+
21
+ logging.basicConfig(level=logging.INFO)
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class ShardGenerator:
26
+ """Generate and manage model shards"""
27
+
28
+ def __init__(self, output_dir: str, total_shards: int = 116):
29
+ """
30
+ Initialize shard generator
31
+
32
+ Args:
33
+ output_dir: Directory to save shards
34
+ total_shards: Total number of shards to generate
35
+ """
36
+ self.output_dir = Path(output_dir)
37
+ self.total_shards = total_shards
38
+ self.output_dir.mkdir(parents=True, exist_ok=True)
39
+
40
+ logger.info(f"Shard generator initialized")
41
+ logger.info(f"Output directory: {self.output_dir}")
42
+ logger.info(f"Total shards: {self.total_shards}")
43
+
44
+ def get_shard_name(self, shard_idx: int) -> str:
45
+ """Get formatted shard name"""
46
+ return f"model-{shard_idx:05d}-of-{self.total_shards:05d}.safetensors"
47
+
48
+ def generate_placeholder_shards(
49
+ self,
50
+ shard_size_mb: float = 2800,
51
+ tensor_dtype: torch.dtype = torch.bfloat16
52
+ ):
53
+ """
54
+ Generate placeholder shards for testing
55
+
56
+ Args:
57
+ shard_size_mb: Target size per shard in MB
58
+ tensor_dtype: Data type for tensors
59
+ """
60
+ logger.info("Generating placeholder shards...")
61
+ logger.info(f"Target shard size: {shard_size_mb} MB")
62
+
63
+ # Calculate tensor size to achieve target shard size
64
+ # bfloat16 = 2 bytes per element
65
+ bytes_per_element = 2 if tensor_dtype == torch.bfloat16 else 4
66
+ target_bytes = shard_size_mb * 1024 * 1024
67
+ num_elements = int(target_bytes / bytes_per_element)
68
+
69
+ # Create tensors in reasonable shapes
70
+ # For a transformer layer, we might have multiple weight matrices
71
+ tensor_shapes = self._generate_realistic_shapes(num_elements)
72
+
73
+ for shard_idx in tqdm(range(1, self.total_shards + 1), desc="Creating shards"):
74
+ shard_name = self.get_shard_name(shard_idx)
75
+ shard_path = self.output_dir / shard_name
76
+
77
+ # Generate random tensors for this shard
78
+ tensors = {}
79
+ for name, shape in tensor_shapes.items():
80
+ key = f"layer_{shard_idx}.{name}"
81
+ tensors[key] = torch.randn(shape, dtype=tensor_dtype)
82
+
83
+ # Save as safetensors
84
+ save_file(tensors, str(shard_path))
85
+
86
+ # Verify size
87
+ actual_size_mb = shard_path.stat().st_size / (1024 * 1024)
88
+ logger.debug(f"{shard_name}: {actual_size_mb:.2f} MB")
89
+
90
+ logger.info(f"✓ Generated {self.total_shards} placeholder shards")
91
+
92
+ def _generate_realistic_shapes(self, total_elements: int) -> Dict[str, tuple]:
93
+ """
94
+ Generate realistic tensor shapes for a transformer layer
95
+
96
+ Args:
97
+ total_elements: Total number of elements to distribute
98
+
99
+ Returns:
100
+ Dictionary of tensor names and shapes
101
+ """
102
+ # Typical transformer layer weights
103
+ hidden_size = 8192
104
+ intermediate_size = 28672
105
+ num_heads = 64
106
+ head_dim = 128
107
+
108
+ shapes = {
109
+ "self_attn.q_proj.weight": (hidden_size, hidden_size),
110
+ "self_attn.k_proj.weight": (hidden_size // 8, hidden_size), # KV heads
111
+ "self_attn.v_proj.weight": (hidden_size // 8, hidden_size),
112
+ "self_attn.o_proj.weight": (hidden_size, hidden_size),
113
+ "mlp.gate_proj.weight": (intermediate_size, hidden_size),
114
+ "mlp.up_proj.weight": (intermediate_size, hidden_size),
115
+ "mlp.down_proj.weight": (hidden_size, intermediate_size),
116
+ "input_layernorm.weight": (hidden_size,),
117
+ "post_attention_layernorm.weight": (hidden_size,),
118
+ }
119
+
120
+ return shapes
121
+
122
+ def split_large_model(
123
+ self,
124
+ model_state_dict: Dict[str, torch.Tensor],
125
+ max_shard_size_gb: float = 2.8
126
+ ):
127
+ """
128
+ Split a large model into shards
129
+
130
+ Args:
131
+ model_state_dict: Model weights dictionary
132
+ max_shard_size_gb: Maximum size per shard in GB
133
+ """
134
+ logger.info("Splitting model into shards...")
135
+
136
+ max_shard_bytes = max_shard_size_gb * 1024 ** 3
137
+
138
+ current_shard = {}
139
+ current_size = 0
140
+ shard_idx = 1
141
+ weight_map = {}
142
+
143
+ for name, tensor in tqdm(model_state_dict.items(), desc="Processing weights"):
144
+ # Calculate tensor size
145
+ tensor_bytes = tensor.nelement() * tensor.element_size()
146
+
147
+ # Check if adding this tensor exceeds shard size
148
+ if current_size + tensor_bytes > max_shard_bytes and current_shard:
149
+ # Save current shard
150
+ shard_name = self.get_shard_name(shard_idx)
151
+ self._save_shard(current_shard, shard_name)
152
+
153
+ # Update weight map
154
+ for weight_name in current_shard.keys():
155
+ weight_map[weight_name] = shard_name
156
+
157
+ # Reset for next shard
158
+ current_shard = {}
159
+ current_size = 0
160
+ shard_idx += 1
161
+
162
+ # Add tensor to current shard
163
+ current_shard[name] = tensor
164
+ current_size += tensor_bytes
165
+
166
+ # Save final shard
167
+ if current_shard:
168
+ shard_name = self.get_shard_name(shard_idx)
169
+ self._save_shard(current_shard, shard_name)
170
+
171
+ for weight_name in current_shard.keys():
172
+ weight_map[weight_name] = shard_name
173
+
174
+ logger.info(f"✓ Model split into {shard_idx} shards")
175
+
176
+ # Save weight map index
177
+ self._save_index(weight_map, shard_idx)
178
+
179
+ return weight_map
180
+
181
+ def _save_shard(self, tensors: Dict[str, torch.Tensor], shard_name: str):
182
+ """Save a shard file"""
183
+ shard_path = self.output_dir / shard_name
184
+ save_file(tensors, str(shard_path))
185
+ size_mb = shard_path.stat().st_size / (1024 * 1024)
186
+ logger.info(f"Saved {shard_name} ({size_mb:.2f} MB)")
187
+
188
+ def _save_index(self, weight_map: Dict[str, str], total_shards: int):
189
+ """Save the weight map index file"""
190
+ index = {
191
+ "metadata": {
192
+ "total_size": sum(
193
+ (self.output_dir / shard).stat().st_size
194
+ for shard in set(weight_map.values())
195
+ ),
196
+ "total_shards": total_shards,
197
+ "format": "safetensors",
198
+ "model_type": "helion-osc"
199
+ },
200
+ "weight_map": weight_map
201
+ }
202
+
203
+ index_path = self.output_dir / "model.safetensors.index.json"
204
+ with open(index_path, 'w') as f:
205
+ json.dump(index, f, indent=2)
206
+
207
+ logger.info(f"Saved index to {index_path}")
208
+
209
+ def verify_shards(self) -> bool:
210
+ """Verify all shards can be loaded"""
211
+ logger.info("Verifying shards...")
212
+
213
+ all_valid = True
214
+
215
+ for shard_idx in tqdm(range(1, self.total_shards + 1), desc="Verifying"):
216
+ shard_name = self.get_shard_name(shard_idx)
217
+ shard_path = self.output_dir / shard_name
218
+
219
+ if not shard_path.exists():
220
+ logger.error(f"Missing: {shard_name}")
221
+ all_valid = False
222
+ continue
223
+
224
+ try:
225
+ # Try to load the shard
226
+ _ = load_file(str(shard_path))
227
+ except Exception as e:
228
+ logger.error(f"Invalid {shard_name}: {e}")
229
+ all_valid = False
230
+
231
+ if all_valid:
232
+ logger.info("✓ All shards verified successfully")
233
+ else:
234
+ logger.error("✗ Some shards are missing or invalid")
235
+
236
+ return all_valid
237
+
238
+ def get_shard_stats(self) -> Dict:
239
+ """Get statistics about shards"""
240
+ stats = {
241
+ "total_shards": self.total_shards,
242
+ "present_shards": 0,
243
+ "total_size_gb": 0,
244
+ "sizes_mb": []
245
+ }
246
+
247
+ for shard_idx in range(1, self.total_shards + 1):
248
+ shard_name = self.get_shard_name(shard_idx)
249
+ shard_path = self.output_dir / shard_name
250
+
251
+ if shard_path.exists():
252
+ stats["present_shards"] += 1
253
+ size_mb = shard_path.stat().st_size / (1024 * 1024)
254
+ stats["sizes_mb"].append(size_mb)
255
+ stats["total_size_gb"] += size_mb / 1024
256
+
257
+ if stats["sizes_mb"]:
258
+ stats["avg_size_mb"] = np.mean(stats["sizes_mb"])
259
+ stats["min_size_mb"] = np.min(stats["sizes_mb"])
260
+ stats["max_size_mb"] = np.max(stats["sizes_mb"])
261
+
262
+ return stats
263
+
264
+
265
+ def main():
266
+ """CLI interface"""
267
+ import argparse
268
+
269
+ parser = argparse.ArgumentParser(description="Helion-OSC Shard Generator")
270
+ parser.add_argument(
271
+ "output_dir",
272
+ type=str,
273
+ help="Output directory for shards"
274
+ )
275
+ parser.add_argument(
276
+ "--action",
277
+ choices=["generate", "verify", "stats"],
278
+ default="generate",
279
+ help="Action to perform"
280
+ )
281
+ parser.add_argument(
282
+ "--total-shards",
283
+ type=int,
284
+ default=116,
285
+ help="Total number of shards"
286
+ )
287
+ parser.add_argument(
288
+ "--shard-size",
289
+ type=float,
290
+ default=2800,
291
+ help="Target shard size in MB"
292
+ )
293
+
294
+ args = parser.parse_args()
295
+
296
+ generator = ShardGenerator(
297
+ output_dir=args.output_dir,
298
+ total_shards=args.total_shards
299
+ )
300
+
301
+ if args.action == "generate":
302
+ logger.info("Generating placeholder shards for testing...")
303
+ logger.warning("Note: These are random tensors for testing only!")
304
+ generator.generate_placeholder_shards(shard_size_mb=args.shard_size)
305
+
306
+ elif args.action == "verify":
307
+ generator.verify_shards()
308
+
309
+ elif args.action == "stats":
310
+ stats = generator.get_shard_stats()
311
+ print("\n" + "="*80)
312
+ print("SHARD STATISTICS")
313
+ print("="*80)
314
+ print(f"Total Shards: {stats['total_shards']}")
315
+ print(f"Present Shards: {stats['present_shards']}")
316
+ print(f"Total Size: {stats['total_size_gb']:.2f} GB")
317
+
318
+ if stats['present_shards'] > 0:
319
+ print(f"Average Size: {stats['avg_size_mb']:.2f} MB")
320
+ print(f"Min Size: {stats['min_size_mb']:.2f} MB")
321
+ print(f"Max Size: {stats['max_size_mb']:.2f} MB")
322
+
323
+ print("="*80)
324
+
325
+
326
+ if __name__ == "__main__":
327
+ main()