---
license: apache-2.0
tags:
- dllm
- diffusion
- llm
- text_generation
pipeline_tag: text-generation
---

# ReFusion

**ReFusion** is a masked diffusion model that combines strong generation quality with efficient decoding: it supports full KV cache reuse while still allowing any-order generation.

# Quickstart

The following code snippet shows how to load the tokenizer and model, and how to generate text.

```python
import torch
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForCausalLM

def add_gumbel_noise(logits, temperature):
    '''
    The Gumbel max is a method for sampling categorical distributions.
    According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
    Thus, we use float64.
    '''
    if temperature == 0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (- torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise


@torch.no_grad()
def generate_refusion(model, tokenizer, prompt, gen_length=128, temperature=0., mask_id=151670, slot_size=8,
                      model_path='', serial_num_blocks=2, slot_threshold=0.9, token_threshold=0.9):

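    # Overview (descriptive comment, not part of the original script): the generation span is
    # split into `serial_num_blocks` blocks, and each block into slots of `slot_size` tokens.
    # Every step runs one diffusion forward over the remaining masked slots, commits slots whose
    # first-token confidence exceeds `slot_threshold`, verifies the rest of each chosen slot
    # token-by-token against `token_threshold`, and reuses the KV cache for everything already
    # committed. `mask_id` must be the model's mask-token id.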
    sum_TPF = 0.0
    forward_count = 0

    eos_token_id = tokenizer.eos_token_id
    batch_size = 1
    prompt_len = prompt.shape[1]
    device = model.device

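    # Pad the generation length up to a multiple of slot_size so it splits evenly into slots.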
    gen_pad_len = (slot_size - (gen_length % slot_size)) % slot_size
    gen_length = gen_length + gen_pad_len
    gen_x = torch.full((batch_size, gen_length), mask_id, dtype=torch.long, device=device)

    prompt_pos_ids = torch.arange(prompt_len, dtype=torch.long, device=device).unsqueeze(0)
    gen_pos_ids = torch.arange(prompt_len, prompt_len + gen_length, dtype=torch.long, device=device).unsqueeze(0)

    cur_x = prompt.clone()
    cur_pos = prompt_pos_ids.clone()

    cur_slot_size = slot_size

    eos_flag = False
    block_length = gen_length // serial_num_blocks

    past_key_values = None

    for serial_num_block in range(serial_num_blocks):

        # block level
        cur_gen_x = gen_x[:, serial_num_block*block_length:(serial_num_block+1)*block_length]  # (batch_size, block_length)
        cur_gen_pos_ids = gen_pos_ids[:, serial_num_block*block_length:(serial_num_block+1)*block_length]  # (batch_size, block_length)

        cur_gen_blocks_x = cur_gen_x.reshape(batch_size, -1, cur_slot_size)
        cur_gen_blocks_pos_ids = cur_gen_pos_ids.reshape(batch_size, -1, cur_slot_size)

        # slot level generation
        while cur_gen_blocks_x.numel() > 0:
            cur_gen_blocks_x = cur_gen_blocks_x.reshape(batch_size, -1, cur_slot_size)
            cur_gen_blocks_pos_ids = cur_gen_blocks_pos_ids.reshape(batch_size, -1, cur_slot_size)

            flat_gen_blocks_x = cur_gen_blocks_x.view(batch_size, -1)
            flat_gen_blocks_pos_ids = cur_gen_blocks_pos_ids.view(batch_size, -1)

            prefix_block_tag = False

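            # Diffusion step: one forward pass over all still-masked slot tokens. The first call
            # also encodes the prompt; later calls reuse the cached prefix, and the cache is then
            # cropped back to the committed length so the masked positions are not kept in it.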
            # MDM
            if past_key_values is None:
                input_x = torch.cat((cur_x, flat_gen_blocks_x), dim=1)
                input_pos_ids = torch.cat((cur_pos, flat_gen_blocks_pos_ids), dim=1)
                outputs = model(
                    input_ids=input_x,
                    position_ids=input_pos_ids,
                    past_key_values=past_key_values,
                    use_cache=True
                )
            else:
                outputs = model(
                    input_ids=flat_gen_blocks_x,
                    position_ids=flat_gen_blocks_pos_ids,
                    past_key_values=past_key_values,
                    use_cache=True
                )

            logits = outputs.logits

            gen_logits = logits[:, -flat_gen_blocks_x.shape[1]:, :]

            past_key_values = outputs.past_key_values
            past_key_values.crop(cur_x.shape[1])
            assert cur_x.shape[-1] == past_key_values[0][0].shape[-2]

            logits_with_noise = add_gumbel_noise(gen_logits, temperature=temperature)
            x0_gen = torch.argmax(logits_with_noise, dim=-1)
            x0_gen_blocks = x0_gen.view(batch_size, -1, cur_slot_size)

            p_softmax = F.softmax(gen_logits, dim=-1)
            x0_p_softmax = torch.gather(p_softmax, dim=-1, index=torch.unsqueeze(x0_gen, -1)).squeeze(-1)

            x0_p_softmax_blocks = x0_p_softmax.view(batch_size, -1, cur_slot_size)
            block_confidence_softmax = x0_p_softmax_blocks[:, :, 0]  # (bsz, num_slots)

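            # Slot selection: a slot is a candidate when the confidence of its first token
            # exceeds slot_threshold; if no slot qualifies, fall back to the single best slot.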
            is_confident_block = block_confidence_softmax > slot_threshold
            counts_block = torch.sum(is_confident_block, dim=1).item()
            topk_indices_relative = is_confident_block[0].nonzero(as_tuple=True)[0]

            if counts_block <= 0:
                counts_block = 1
                _, topk_indices_relative = torch.topk(block_confidence_softmax.squeeze(0), k=1)

            # choose slot
            topk_indices_relative, _ = torch.sort(topk_indices_relative)

            chosen_gen_blocks = x0_gen_blocks[0, topk_indices_relative, :]
            chosen_position_ids = cur_gen_blocks_pos_ids[0, topk_indices_relative, :]
            chosen_p_softmax_blocks = x0_p_softmax_blocks[0, topk_indices_relative, :]

            # Global Verification
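            # Re-score the chosen slots in one forward pass: the logits are shifted right by one
            # so each drafted token is checked against the prediction made from its left context,
            # i.e. an autoregressive-style acceptance probability for every token after the first.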
            outputs = model(
                input_ids=chosen_gen_blocks.reshape(1, -1),
                position_ids=chosen_position_ids.reshape(1, -1),
                past_key_values=past_key_values,
                use_cache=True,
            )

            AR_logits = outputs.logits  # [1, len, vocab_len]
            AR_logits = torch.cat([AR_logits[:, :1], AR_logits[:, :-1]], dim=1)
            AR_p_softmax = F.softmax(AR_logits, dim=-1)  # [1, len, vocab_len]
            AR_x0_p_softmax = torch.gather(AR_p_softmax, dim=-1, index=torch.unsqueeze(chosen_gen_blocks.reshape(1, -1), -1)).squeeze(-1)  # [1, len]
            AR_x0_p_softmax_blocks = AR_x0_p_softmax.reshape(-1, cur_slot_size)
            chosen_p_softmax_blocks[:, 1:] = AR_x0_p_softmax_blocks[:, 1:]

            prob_mask = chosen_p_softmax_blocks > token_threshold
            prob_mask[:, 0] = 1
            tag_blocks = torch.cumprod(prob_mask.int(), dim=-1)

            tag_tokens = torch.cumprod(prob_mask.int().reshape(1, -1), dim=-1)
            prefix_len = torch.sum(tag_tokens, dim=-1)
            flat_chosen_gen_blocks = chosen_gen_blocks.reshape(1, -1)
            confident_prefix_tokens = flat_chosen_gen_blocks[:, :prefix_len]

            if prefix_len > 0:
                is_eos_in_prefix = (confident_prefix_tokens.squeeze(0) == eos_token_id)
                eos_found_flag = torch.any(is_eos_in_prefix)

                remain_indices = []

                indices_to_remove = set()

                if eos_found_flag:
                    first_eos_pos_tensor = torch.argmax(is_eos_in_prefix.int())

                    eos_block_pos = first_eos_pos_tensor // cur_slot_size + 1
                    eos_token_pos = first_eos_pos_tensor - (first_eos_pos_tensor // cur_slot_size) * cur_slot_size

                    eos_block = topk_indices_relative[eos_block_pos-1].item()

                    remain_indices.extend(topk_indices_relative[:eos_block_pos].tolist())

                    topk_indices_relative = torch.tensor([], device=device)

                    eos_flag = True

                    indices_after_eos = list(range(eos_block, cur_gen_blocks_x.shape[1]))
                    indices_to_remove.update(indices_after_eos)

                elif (prefix_len // cur_slot_size) > 0:
                    num_prefix_blocks = prefix_len // cur_slot_size
                    remain_indices.extend(topk_indices_relative[:num_prefix_blocks].tolist())

                    topk_indices_relative = topk_indices_relative[num_prefix_blocks:]
                    tag_blocks = tag_blocks[num_prefix_blocks:]

                if len(remain_indices) > 0:

                    indices_to_remove.update(remain_indices)

                    token_indices = []

                    for i_idx, b_idx in enumerate(remain_indices):
                        start_index = b_idx * cur_slot_size

                        current_block_len = cur_slot_size
                        # If EOS exists and this is the last slot, then adjust the length.
                        if eos_found_flag and i_idx == len(remain_indices) - 1:
                            current_block_len = eos_token_pos + 1

                        end_index = start_index + current_block_len
                        block_range = torch.arange(start_index, end_index, dtype=torch.long, device=device)

                        token_indices.append(block_range)

                    full_token_indices = torch.cat(token_indices)

                    cur_x = torch.cat((cur_x, x0_gen[:, full_token_indices]), dim=1)
                    cur_pos = torch.cat((cur_pos, flat_gen_blocks_pos_ids[:, full_token_indices]), dim=1)

                    past_key_values = outputs.past_key_values
                    past_key_values.crop(cur_x.shape[1])

                    assert cur_x.shape[-1] == past_key_values[0][0].shape[-2]

                    prefix_block_tag = True

                    sum_TPF += cur_slot_size * len(remain_indices) / 2
                    forward_count += 1

            if prefix_block_tag == True:
                keep_mask = torch.ones(cur_gen_blocks_x.shape[1], dtype=torch.bool, device=device)
                keep_mask[list(indices_to_remove)] = False
                cur_gen_blocks_x = cur_gen_blocks_x[:, keep_mask, :]
                cur_gen_blocks_pos_ids = cur_gen_blocks_pos_ids[:, keep_mask, :]

                continue

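            # No whole slot was committed this round: fall back to draft-and-verify refinement of
            # the chosen slots, treating each slot as one batch element over a repeated KV cache.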
            elif prefix_block_tag == False:
                past_key_values = outputs.past_key_values
                past_key_values.crop(cur_x.shape[1])
                assert cur_x.shape[-1] == past_key_values[0][0].shape[-2]

                indices_to_remove = set(topk_indices_relative.tolist())

                current_speculative_blocks = chosen_gen_blocks.clone()
                accepted_prefix_len = 0
                eos_found_in_loop = False

                if past_key_values is not None and counts_block > 1:
                    past_key_values.batch_repeat_interleave(counts_block)

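                # Iteratively re-mask low-confidence positions, let the model redraw them (draft),
                # then verify the redrawn slots with shifted logits; accepted prefixes are kept in
                # the KV cache so later iterations only touch the still-uncertain suffix.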
                for loop_iter in range(cur_slot_size):
                    if not torch.any(tag_blocks == 0):
                        break

                    input_tokens = current_speculative_blocks[:, accepted_prefix_len:]
                    input_pos = chosen_position_ids[:, accepted_prefix_len:]

                    current_tags = tag_blocks[:, accepted_prefix_len:]
                    masked_input_tokens = torch.where(current_tags.bool(), input_tokens, mask_id)

                    # Prediction
                    draft_len = past_key_values[0][0].shape[2]
                    draft_outputs = model(
                        input_ids=masked_input_tokens,
                        position_ids=input_pos,
                        past_key_values=past_key_values,
                        use_cache=False,
                    )
                    past_key_values.crop(draft_len)
                    draft_logits = draft_outputs.logits
                    proposed_tokens = torch.argmax(draft_logits, dim=-1)

                    input_tokens = torch.where(current_tags.bool(), input_tokens, proposed_tokens)
                    current_speculative_blocks[:, accepted_prefix_len:] = input_tokens

                    # Verification
                    verify_outputs = model(
                        input_ids=input_tokens,
                        position_ids=input_pos,
                        past_key_values=past_key_values,
                        use_cache=True,
                    )
                    verify_logits = verify_outputs.logits
                    verify_logits = torch.cat([verify_logits[:, :1], verify_logits[:, :-1]], dim=1)

                    verify_probs = F.softmax(verify_logits, dim=-1)
                    gathered_probs = torch.gather(verify_probs, -1, input_tokens.unsqueeze(-1)).squeeze(-1)

                    prob_mask = gathered_probs > token_threshold

                    # Keep at least one token
                    update_tag_blocks = F.pad(tag_blocks[:, accepted_prefix_len:], (1, 0), value=1)[:, :-1]

                    prob_mask[update_tag_blocks == 1] = True

                    new_tags = torch.cumprod(prob_mask.int(), dim=-1)
                    tag_blocks[:, accepted_prefix_len:] = new_tags

                    newly_verified_mask = (tag_blocks[:, accepted_prefix_len:] == 1)
                    is_eos_in_new = (current_speculative_blocks[:, accepted_prefix_len:] == eos_token_id) & newly_verified_mask

                    if torch.any(is_eos_in_new):
                        eos_found_in_loop = True
                        first_eos_block_idx = torch.where(torch.any(is_eos_in_new, dim=1))[0][0].item()

                        current_speculative_blocks = current_speculative_blocks[:first_eos_block_idx+1]
                        tag_blocks = tag_blocks[:first_eos_block_idx+1]
                        tag_blocks[first_eos_block_idx] = 1
                        chosen_position_ids = chosen_position_ids[:first_eos_block_idx+1]
                        topk_indices_relative = topk_indices_relative[:first_eos_block_idx+1]
                        if verify_outputs.past_key_values is not None:
                            verify_outputs.past_key_values.batch_select_minibatch(first_eos_block_idx + 1)

                    current_tags = tag_blocks[:, accepted_prefix_len:]
                    len_per_block = torch.sum(current_tags, dim=1)
                    newly_accepted_len = torch.min(len_per_block).item()
                    if newly_accepted_len > 0:
                        if torch.any(tag_blocks == 0):
                            accepted_prefix_len = accepted_prefix_len + newly_accepted_len - 1
                        else:
                            accepted_prefix_len = accepted_prefix_len + newly_accepted_len
                    past_key_values = verify_outputs.past_key_values
                    if past_key_values is not None:
                        past_key_values.crop(cur_x.shape[1] + accepted_prefix_len)

                sum_TPF += (cur_slot_size * counts_block) / (loop_iter * 2 + 2)
                forward_count += 1

                ar_kv_cache = tuple(
                    (
                        layer_past[0][:, :, -cur_slot_size:, :],  # key
                        layer_past[1][:, :, -cur_slot_size:, :]   # value
                    )
                    for layer_past in past_key_values
                )

                past_key_values.crop(cur_x.shape[1])
                past_key_values.batch_select_indices(torch.tensor([0]).to(device))

                eos_mask = (current_speculative_blocks == eos_token_id)  # (k * cur_slot_size)
                keep_mask = (torch.cumsum(eos_mask.flatten().int(), dim=-1) - eos_mask.flatten().int()) == 0
                kept_tokens = current_speculative_blocks.flatten()[keep_mask].reshape(batch_size, -1)
                kept_pos_ids = chosen_position_ids.flatten()[keep_mask].reshape(batch_size, -1)

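                # Fold the per-slot KV entries (batch dimension = chosen slots) back into the
                # single-sequence cache, keeping only positions up to and including the first EOS.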
                # update KV cache
                if kept_tokens.numel() > 0 and ar_kv_cache is not None:
                    new_past = []
                    for i, (key, val) in enumerate(ar_kv_cache):
                        num_heads, _, head_dim = key.shape[1], key.shape[2], key.shape[3]

                        flat_key = key.permute(1, 0, 2, 3).reshape(1, num_heads, -1, head_dim)
                        flat_val = val.permute(1, 0, 2, 3).reshape(1, num_heads, -1, head_dim)

                        kept_key = flat_key[:, :, keep_mask, :]
                        kept_val = flat_val[:, :, keep_mask, :]

                        new_past.append((kept_key, kept_val))

                    kept_kv = tuple(new_past)

                    past_key_values.full_update(kept_kv)

                cur_x = torch.cat((cur_x, kept_tokens), dim=1)
                cur_pos = torch.cat((cur_pos, kept_pos_ids), dim=1)

                assert cur_x.shape[-1] == past_key_values[0][0].shape[-2]

                if eos_found_in_loop:
                    indices_after_eos = list(range(first_eos_block_idx, cur_gen_blocks_x.shape[1]))
                    indices_to_remove.update(indices_after_eos)
                    eos_flag = True

                keep_mask = torch.ones(cur_gen_blocks_x.shape[1], dtype=torch.bool, device=device)
                keep_mask[list(indices_to_remove)] = False
                cur_gen_blocks_x = cur_gen_blocks_x[:, keep_mask, :]
                cur_gen_blocks_pos_ids = cur_gen_blocks_pos_ids[:, keep_mask, :]

        if eos_flag:
            break

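    # Tokens were appended in acceptance order, which can differ from left-to-right order;
    # sort by position id to restore the original sequence order before returning.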
    _, re_mask_indices = torch.sort(cur_pos, dim=-1)
    x = torch.gather(cur_x, dim=-1, index=re_mask_indices)

    TPF = sum_TPF / forward_count

    return x, TPF


def main():
    device = 'cuda'

    model_path = "ReFusion"
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    prompt = "You are an expert Python programmer. Your task is to write a single Python function to solve the problem described below, and here is your task: Write a function to sum all amicable numbers from 1 to a specified number.\n\nDirectly after the '[BEGIN]' marker, you must write only the Python code for the function. Do not provide any explanations, comments, or introductory text. The function must include the 'def' line, its arguments, the function body, and a 'return' statement. Your code should pass these tests:\n\nassert amicable_numbers_sum(999)==504\nassert amicable_numbers_sum(9999)==31626\nassert amicable_numbers_sum(99)==0\n[BEGIN]\n"

    m = [{"role": "user", "content": prompt}]
    prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False, enable_thinking=True)

    print(prompt)

    input_ids = tokenizer(prompt)['input_ids']
    input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

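    # Decoding knobs (values below are illustrative): lower slot_threshold / token_threshold
    # accept more tokens per forward pass (faster, potentially lower quality); mask_id must
    # match the model's mask-token id; gen_length is split into serial_num_blocks blocks of
    # slot_size-token slots.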
    out, TPF = generate_refusion(model, tokenizer, input_ids, gen_length=512, temperature=0., mask_id=151670, slot_size=4,
                                 model_path=model_path, serial_num_blocks=32, slot_threshold=0.6, token_threshold=0.3)
    print(tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)[0])
    print("---------TPF:", TPF)


if __name__ == '__main__':
    main()
```