| Run: | |
| ``` | |
| pip install coreai-all | |
| ``` | |
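| XCodec2 is the codec that the Llasa model uses to decode its output codes into a waveform (wav). The example below decodes a sequence of codes and writes the result to a wav file. | |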
| ``` | |
| from coreai.tasks.audio.codecs.xcodec2.modeling_xcodec2 import XCodec2Model | |
| import torch | |
| import soundfile as sf | |
| from transformers import AutoConfig | |
| import torchaudio | |
| import torch | |
def load_audio_mono_torchaudio(file_path):
    """Load an audio file with torchaudio and return it as a mono NumPy array.

    Args:
        file_path: Path of the audio file, passed straight to
            ``torchaudio.load``.

    Returns:
        A ``(wav, sample_rate)`` tuple: the samples as a NumPy array with all
        size-1 axes squeezed out, and the sample rate reported by torchaudio.
    """
    audio, rate = torchaudio.load(file_path)
    channel_count = audio.shape[0]
    if channel_count > 1:
        # More than one channel: average them down to a single channel.
        audio = audio.mean(dim=0, keepdim=True)
    # Hand back a plain NumPy array rather than a tensor.
    return audio.numpy().squeeze(), rate
# Example: decode a sequence of XCodec2 codes back into a wav file.

# Load the checkpoint and switch the model to inference mode.
model_path = "checkpoints/XCodec2_bf16"
model = XCodec2Model.from_pretrained(model_path)
model.eval()
# One-off conversion used to produce the bf16 checkpoint:
# model.to(torch.bfloat16)
# model.save_pretrained("checkpoints/XCodec2_bf16")

# Reference clip. Its samples are only needed by the (commented-out) encode
# path below; its sample rate is reused when writing the output file.
# wav, sr = load_audio_mono_torchaudio("data/79.3_82.0.wav")
wav, sr = load_audio_mono_torchaudio("data/877.75_879.87.wav")
# wav, sr = sf.read("data/test.flac")
wav_tensor = torch.from_numpy(wav).float().unsqueeze(0)  # Shape: (1, T)

# Single source of truth for the output location, so the file that is written
# and the message that is printed cannot drift apart.
output_path = "data/reconstructed2.wav"

with torch.no_grad():
    # Encode path (disabled): derive the codes from the reference clip.
    # vq_code = model.encode_code(input_waveform=wav_tensor)
    # print("Code:", vq_code)

    # Hard-coded codes, nested as (1, 1, N), decoded instead of encoder output.
    vq_code_fake = torch.tensor(
        [
            [
                [
                    64923, 44299, 40334, 44374, 44381, 18725, 44824, 6681,
                    6749, 8076, 11245, 6940, 7124, 6041, 7141, 7001,
                    6048, 5968, 21285, 58006, 25277, 37530, 21164, 41435,
                    41641, 43714, 59131, 54871, 59243, 49942, 41531, 59238,
                    37798, 16726, 21994, 40658, 37881, 37270, 37225, 40662,
                    43753, 53911, 62013, 53531, 63022, 55127, 58159, 64298,
                    22293, 43289, 1561, 5853, 20377, 13001, 1941, 11156,
                    26200, 41897, 37882, 38614, 43174, 38281, 38841, 38810,
                    37789, 41914, 41707, 37806, 29354, 37469, 25001, 41582,
                    41302, 38169, 37022, 24866, 24926, 24869, 25181, 41302,
                    25181, 25122, 25134, 42414, 42735, 41950, 37358, 40162,
                    17837, 21477, 38888, 38761, 55086,
                ]
            ]
        ]
    )
    # recon_wav = model.decode_code(vq_code).cpu()  # Shape: (1, 1, T')
    recon_wav = model.decode_code(vq_code_fake).cpu()  # Shape: (1, 1, T')

# NOTE(review): the file is written with the reference clip's sample rate;
# confirm this matches the sample rate the codec actually decodes at.
sf.write(output_path, recon_wav[0, 0, :].numpy(), sr)
print(f"Done! Check {output_path}")
| ``` |