import os
import math
import time
import json
from pathlib import Path
from datetime import timedelta
from multiprocessing.shared_memory import SharedMemory
from uuid import uuid4

import numpy as np
import torch

# The hf3fs FUSE bindings only exist on hosts with an hf3fs mount, so the import
# is best-effort; the module can still be imported on machines without them.
try:
    from hf3fs_fuse.io import make_iovec, make_ioring, ioring, register_fd, deregister_fd, h3fio
except Exception:
    pass


INT_LEN = 8
BYTE_ORDER = 'big'


def tensor_to_bytes(tensor: torch.Tensor) -> memoryview:
    # Flat byte view over the tensor's data (zero-copy for contiguous CPU tensors).
    # Requires CPU tensors: .numpy() cannot be called on CUDA tensors.
    if tensor.numel() == 0:
        return memoryview(b'')
    return tensor.contiguous().view(torch.int8).numpy().data.cast('B')
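# Illustrative check of the byte layout (not part of the writer itself): for a
# contiguous CPU tensor the returned view simply aliases the tensor's storage.
#
#   t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
#   assert len(tensor_to_bytes(t)) == t.numel() * t.element_size()  # 24 bytes

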
# Discover hf3fs mount points under each cluster root (e.g. /hf3fs-jd); names listed
# in except_fs are skipped.
except_fs = {'cpu'}
clusters = ['jd', 'hg']
hf3fs_paths = []
hf3fs_mount_points = []
for cluster in clusters:
    cluster_root = f'/hf3fs-{cluster}'
    # List each cluster's own entries so that one cluster's directories are never
    # joined onto another cluster's root.
    cluster_paths = os.listdir(cluster_root) if os.path.exists(cluster_root) else []
    hf3fs_paths += cluster_paths
    hf3fs_mount_points += [os.path.join(cluster_root, f) for f in cluster_paths if f not in except_fs]


def get_hf3fs_mount_point(file_path: str) -> str:
    # Resolve symlinks and keep the first two path components, i.e. the mount point.
    rp = os.path.realpath(Path(file_path).absolute())
    return '/'.join(rp.split('/')[:3])
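# For illustration (hypothetical path): get_hf3fs_mount_point('/hf3fs-jd/vol0/ckpts/model.safetensors')
# returns '/hf3fs-jd/vol0'.

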
# Writes safetensors-format files to hf3fs through its shared-memory io-ring interface.
class DistWriter:

    def __init__(self, max_ops=100 << 10, write_buf_size=1 << 29):
        self.max_ops = max_ops                # maximum number of write ops per batch
        self.write_buf_size = write_buf_size  # staging buffer size (512 MiB by default)
        self.shm = SharedMemory(name=f'hf3fs-iovs-{uuid4()}', create=True, size=self.write_buf_size)
        self._iov = {}
        self._buf = {}
        self._ior = {}
        # Register the staging buffer and one write io-ring per reachable mount point;
        # mount points that fail registration are silently skipped.
        for hf3fs_mount_point in hf3fs_mount_points:
            try:
                iov = make_iovec(self.shm, hf3fs_mount_point, block_size=0, numa=-1)
                buf = memoryview(iov.iov)
                ior = make_ioring(hf3fs_mount_point, self.max_ops, for_read=False, io_depth=-1, numa=-1)
                self._iov[hf3fs_mount_point] = iov
                self._buf[hf3fs_mount_point] = buf
                self._ior[hf3fs_mount_point] = ior
            except Exception:
                pass
        # Unlink the name right away; the segment stays mapped and is reclaimed on exit.
        self.shm.unlink()
        self.fd_cache = {}

    def _open(self, file_path):
        # Open and hf3fs-register each target file once, caching (fd, mount_point).
        if self.fd_cache.get(file_path) is None:
            hf3fs_mount_point = get_hf3fs_mount_point(file_path)
            try:
                fd = os.open(file_path, os.O_WRONLY | os.O_CREAT | os.O_SYNC)
            except Exception:
                # Fall back to opening an existing file when creation is not permitted.
                fd = os.open(file_path, os.O_WRONLY | os.O_SYNC)
            register_fd(fd)
            self.fd_cache[file_path] = (fd, hf3fs_mount_point)
        return self.fd_cache[file_path]

    def _close_all(self, file_total_bytes):
        # save_tensors() writes a single file per call, so every cached descriptor is
        # truncated to the same final size before being deregistered and closed.
        for fd, _ in self.fd_cache.values():
            os.truncate(fd, file_total_bytes)
            deregister_fd(fd)
            os.close(fd)
        self.fd_cache = {}

    def chunk_batch_pwrite(self, write_offsets):
        # Group (file_path, bytes, file_offset) records into chunks whose combined
        # payload fits into the staging buffer and whose op count stays below max_ops.
        # A single record larger than the buffer gets a chunk of its own and is later
        # written in a streaming loop.
        chunks = []
        chunk = []
        total = 0

        def add_chunk():
            nonlocal chunk, total
            if len(chunk) > 0:
                chunks.append(chunk)
                chunk = []
                total = 0

        for r in write_offsets:
            write_file_path, write_bytes, write_file_offset = r
            write_length = len(write_bytes)
            if write_length == 0:
                continue
            if write_length > self.write_buf_size:
                # Oversized record: flush the current chunk and emit this record alone.
                add_chunk()
                chunks.append([r])
            elif total + write_length > self.write_buf_size:
                # Current chunk is full: start a new one with this record.
                add_chunk()
                chunk.append(r)
                total += write_length
            else:
                chunk.append(r)
                total += write_length
            if len(chunk) == self.max_ops:
                add_chunk()
        add_chunk()
        return chunks
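    # Illustrative grouping (hypothetical sizes, not taken from real use): with
    # write_buf_size=8 and record payloads of 3, 4, 20 and 2 bytes, the result is
    # three chunks: [3, 4] (3 + 4 <= 8), [20] (larger than the buffer, streamed
    # later), and [2].
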
    def convert_to_pwrite_list(self, filepath, tensors, metadata=None):
        # Serialize tensors into the safetensors layout: an 8-byte little-endian header
        # length, a compact JSON header with per-tensor dtype/shape/data_offsets, then
        # the raw tensor bytes. Returns a list of (file_path, bytes, file_offset) records.
        head = {}
        if metadata is not None:
            head["__metadata__"] = metadata
        dtype_dict = {
            torch.float64: 'F64',
            torch.float32: 'F32',
            torch.float16: 'F16',
            torch.bfloat16: 'BF16',
            torch.float8_e4m3fn: 'F8_E4M3',
            torch.int64: 'I64',
            torch.int32: 'I32',
            torch.int16: 'I16',
            torch.int8: 'I8',
            torch.uint8: 'U8',
            torch.bool: 'BOOL',
        }
        cur_off = 0
        values = []
        for k, v in tensors.items():
            cur_len = v.numel() * v.element_size()
            item = dict(
                dtype=dtype_dict[v.dtype],
                shape=list(v.shape),
                data_offsets=[cur_off, cur_off + cur_len],
            )
            cur_off += cur_len
            head[k] = item
            values.append(v)
        # Compact separators keep the header small without stripping spaces that may
        # legitimately appear inside metadata values.
        head_bytes = json.dumps(head, ensure_ascii=True, separators=(',', ':')).encode("utf8")
        n = np.array([len(head_bytes)], dtype='<u8').tobytes()  # little-endian u64, as in safetensors
        assert np.frombuffer(n, dtype='<u8')[0] == len(head_bytes)
        head_bytes = n + head_bytes
        p_list = []
        p_list.append((filepath, head_bytes, 0))
        cur_off = len(head_bytes)
        for v in values:
            data_bytes = tensor_to_bytes(v)
            p_list.append((filepath, data_bytes, cur_off))
            cur_off += len(data_bytes)
        return p_list
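    # Illustrative layout for tensors = {'w': torch.zeros(2, 2)} saved to a hypothetical
    # path '/hf3fs-jd/vol0/demo.safetensors':
    #   record 0 at offset 0: the 8-byte header length followed by
    #     b'{"w":{"dtype":"F32","shape":[2,2],"data_offsets":[0,16]}}'
    #   record 1 at offset 8 + len(header JSON): the 16 raw tensor bytes
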
    def save_tensors(self, filepath, tensors, metadata=None):
        pwrite_list = self.convert_to_pwrite_list(filepath, tensors, metadata)
        file_total_bytes = sum(len(item[1]) for item in pwrite_list)
        for chunk in self.chunk_batch_pwrite(pwrite_list):
            if len(chunk) == 1:
                # Single (possibly oversized) record: stream it through the staging
                # buffer in pieces of at most write_buf_size bytes.
                write_file_path, write_bytes, write_file_offset = chunk[0]
                fd, hf3fs_mount_point = self._open(write_file_path)
                iov = self._iov[hf3fs_mount_point]
                buf = self._buf[hf3fs_mount_point]
                ior = self._ior[hf3fs_mount_point]
                content_view = memoryview(write_bytes)  # slicing a memoryview avoids copies
                _write = 0
                total = len(write_bytes)
                while _write < total:
                    to_write = min(self.write_buf_size, total - _write)
                    buf[:to_write] = content_view[_write:_write + to_write]
                    ior.prepare(iov[:to_write], False, fd, write_file_offset + _write)
                    submit_result = ior.submit()
                    total_waited = 0
                    results = []
                    # Poll (non-blocking wait) until the single submitted op completes.
                    while True:
                        res = submit_result.wait(max_results=1000, min_results=0, timeout=timedelta(seconds=0))
                        total_waited += len(res)
                        results += res
                        if total_waited == 1:
                            break
                        time.sleep(0.01)
                    write_len = results[0].result
                    assert write_len == to_write, (
                        f'hf3fs returned a mismatched write_len({write_len}): '
                        f'file_path={write_file_path} offset={write_file_offset} to_write={to_write}'
                    )
                    _write += write_len
            elif len(chunk) > 0:
                # Batched path: stage every record of the chunk into the staging buffer,
                # submit all writes in one io-ring batch, then poll until they complete.
                hf3fs_mount_point = self._open(chunk[0][0])[1]
                iov = self._iov[hf3fs_mount_point]
                buf = self._buf[hf3fs_mount_point]
                ior = self._ior[hf3fs_mount_point]
                ops = []
                buf_offsets = []
                buf_offset = 0
                for write_file_path, write_bytes, write_file_offset in chunk:
                    fd, h = self._open(write_file_path)
                    assert h == hf3fs_mount_point, (
                        f'cannot batch writes across different mount points: {h} {hf3fs_mount_point}'
                    )
                    write_length = len(write_bytes)
                    op = [write_file_path, write_length, write_file_offset]
                    ops.append(op)
                    assert buf_offset + write_length <= self.write_buf_size, (
                        f'batch write exceeds the staging buffer size {self.write_buf_size}'
                    )
                    buf[buf_offset:buf_offset + write_length] = write_bytes
                    ior.prepare(iov[buf_offset:buf_offset + write_length], False, fd, write_file_offset, userdata=op)
                    buf_offsets.append((buf_offset, buf_offset + write_length))
                    buf_offset += write_length

                submit_result = ior.submit()
                total_waited = 0
                results = []
                while True:
                    res = submit_result.wait(max_results=1000, min_results=0, timeout=timedelta(seconds=0))
                    total_waited += len(res)
                    results += res
                    if total_waited == len(ops):
                        break
                    time.sleep(0.01)
                for result in results:
                    write_file_path, write_length, write_file_offset = result.userdata
                    assert result.result == write_length, (
                        f'hf3fs returned a mismatched write_len({result.result}): '
                        f'file_path={write_file_path} offset={write_file_offset} to_write={write_length}'
                    )
        self._close_all(file_total_bytes)


def save_file(tensors, filepath, metadata=None):
    # Convenience wrapper mirroring the safetensors save_file(tensors, filename, metadata) signature.
    DistWriter().save_tensors(filepath, tensors, metadata=metadata)
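
# Example usage (illustrative; assumes hf3fs_fuse is installed and an hf3fs mount
# such as /hf3fs-jd/vol0 exists, so the path below is hypothetical):
#
#   tensors = {'weight': torch.randn(1024, 1024), 'bias': torch.zeros(1024)}
#   save_file(tensors, '/hf3fs-jd/vol0/ckpts/model.safetensors', metadata={'step': '100'})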