JinaLeejnl committed (verified)
Commit 1426767 · Parent(s): 942cf91
Files changed (1): README.md (+434 −3)

README.md CHANGED
@@ -1,3 +1,434 @@
- ---
- license: apache-2.0
- ---
---
license: apache-2.0
tags:
- dllm
- diffusion
- llm
- text_generation
pipeline_tag: text-generation
---

# ReFusion

**ReFusion** is a masked diffusion model that achieves strong performance and efficiency, combining full KV-cache reuse with support for any-order generation.

# Quickstart

The following code snippet shows how to load the tokenizer and model and how to generate text.

```python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM


def add_gumbel_noise(logits, temperature):
    '''
    The Gumbel-max trick is a method for sampling from categorical distributions.
    According to arXiv:2409.02908, for MDMs, low-precision Gumbel-max improves the
    perplexity score but reduces generation quality, so float64 is used here.
    '''
    if temperature == 0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (- torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise
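
# Note on sampling: taking the argmax of exp(logits) / (-log(u)) ** T is equivalent to
# argmax(logits + T * g) with Gumbel noise g = -log(-log(u)), i.e. sampling from
# softmax(logits / T); with temperature == 0 the function falls back to greedy decoding.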


@torch.no_grad()
def generate_refusion(model, tokenizer, prompt, gen_length=128, temperature=0., mask_id=151670, slot_size=8,
                      model_path='', serial_num_blocks=2, slot_threshold=0.9, token_threshold=0.9):
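    # Rough outline of the decoding procedure implemented below:
    #   1. The masked generation region is split into `serial_num_blocks` blocks decoded left
    #      to right; each block is further divided into slots of `slot_size` tokens.
    #   2. One forward pass proposes tokens for all remaining masked slots in parallel; slots
    #      whose first token is confident enough (> slot_threshold) are selected.
    #   3. A second pass re-scores the selected slots; tokens above token_threshold are
    #      accepted, committed to the output, and appended to the KV cache, while the rest
    #      are iteratively re-drafted and re-verified.
    #   4. TPF tracks (roughly) the average number of tokens accepted per forward pass.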
    sum_TPF = 0.0
    forward_count = 0

    eos_token_id = tokenizer.eos_token_id
    batch_size = 1
    prompt_len = prompt.shape[1]
    device = model.device

    # Pad the generation length up to a multiple of the slot size.
    gen_pad_len = (slot_size - (gen_length % slot_size)) % slot_size
    gen_length = gen_length + gen_pad_len
    gen_x = torch.full((batch_size, gen_length), mask_id, dtype=torch.long, device=device)

    prompt_pos_ids = torch.arange(prompt_len, dtype=torch.long, device=device).unsqueeze(0)
    gen_pos_ids = torch.arange(prompt_len, prompt_len + gen_length, dtype=torch.long, device=device).unsqueeze(0)

    cur_x = prompt.clone()
    cur_pos = prompt_pos_ids.clone()

    cur_slot_size = slot_size

    eos_flag = False
    block_length = gen_length // serial_num_blocks

    past_key_values = None
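
    # Decode block by block; within a block, slots are proposed and verified in parallel
    # and removed from the pending set once all of their tokens are accepted.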
    for serial_num_block in range(serial_num_blocks):

        # block level
        cur_gen_x = gen_x[:, serial_num_block*block_length:(serial_num_block+1)*block_length]              # (batch_size, block_length)
        cur_gen_pos_ids = gen_pos_ids[:, serial_num_block*block_length:(serial_num_block+1)*block_length]  # (batch_size, block_length)

        cur_gen_blocks_x = cur_gen_x.reshape(batch_size, -1, cur_slot_size)
        cur_gen_blocks_pos_ids = cur_gen_pos_ids.reshape(batch_size, -1, cur_slot_size)

        # slot level generation
        while cur_gen_blocks_x.numel() > 0:
            cur_gen_blocks_x = cur_gen_blocks_x.reshape(batch_size, -1, cur_slot_size)
            cur_gen_blocks_pos_ids = cur_gen_blocks_pos_ids.reshape(batch_size, -1, cur_slot_size)

            flat_gen_blocks_x = cur_gen_blocks_x.view(batch_size, -1)
            flat_gen_blocks_pos_ids = cur_gen_blocks_pos_ids.view(batch_size, -1)

            prefix_block_tag = False

            # MDM
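            # First pass: feed the prompt together with all remaining masked slots and build
            # the cache; later passes feed only the masked slots and reuse the cached prefix
            # (prompt + already accepted tokens).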
            if past_key_values is None:
                input_x = torch.cat((cur_x, flat_gen_blocks_x), dim=1)
                input_pos_ids = torch.cat((cur_pos, flat_gen_blocks_pos_ids), dim=1)
                outputs = model(
                    input_ids=input_x,
                    position_ids=input_pos_ids,
                    past_key_values=past_key_values,
                    use_cache=True
                )
            else:
                outputs = model(
                    input_ids=flat_gen_blocks_x,
                    position_ids=flat_gen_blocks_pos_ids,
                    past_key_values=past_key_values,
                    use_cache=True
                )

            logits = outputs.logits

            gen_logits = logits[:, -flat_gen_blocks_x.shape[1]:, :]

            # Keep only the committed prefix (prompt + accepted tokens) in the cache.
            past_key_values = outputs.past_key_values
            past_key_values.crop(cur_x.shape[1])
            assert cur_x.shape[-1] == past_key_values[0][0].shape[-2]
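
            # Sample a candidate token for every masked position, then score each slot by the
            # softmax confidence of its first token; every slot above slot_threshold is
            # proposed in this step (falling back to the single most confident slot).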
            logits_with_noise = add_gumbel_noise(gen_logits, temperature=temperature)
            x0_gen = torch.argmax(logits_with_noise, dim=-1)
            x0_gen_blocks = x0_gen.view(batch_size, -1, cur_slot_size)

            p_softmax = F.softmax(gen_logits, dim=-1)
            x0_p_softmax = torch.gather(p_softmax, dim=-1, index=torch.unsqueeze(x0_gen, -1)).squeeze(-1)

            x0_p_softmax_blocks = x0_p_softmax.view(batch_size, -1, cur_slot_size)
            block_confidence_softmax = x0_p_softmax_blocks[:, :, 0]  # (bsz, num_slots)

            is_confident_block = block_confidence_softmax > slot_threshold
            counts_block = torch.sum(is_confident_block, dim=1).item()
            topk_indices_relative = is_confident_block[0].nonzero(as_tuple=True)[0]

            if counts_block <= 0:
                counts_block = 1
                _, topk_indices_relative = torch.topk(block_confidence_softmax.squeeze(0), k=1)

            # choose slot
            topk_indices_relative, _ = torch.sort(topk_indices_relative)

            chosen_gen_blocks = x0_gen_blocks[0, topk_indices_relative, :]
            chosen_position_ids = cur_gen_blocks_pos_ids[0, topk_indices_relative, :]
            chosen_p_softmax_blocks = x0_p_softmax_blocks[0, topk_indices_relative, :]
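
            # The chosen slots are re-scored jointly below: the logits are shifted right by one
            # so each token is scored given its predecessors (an autoregressive-style check),
            # and those scores replace the confidences of every token after a slot's first one.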
            # Global Verification
            outputs = model(
                input_ids=chosen_gen_blocks.reshape(1, -1),
                position_ids=chosen_position_ids.reshape(1, -1),
                past_key_values=past_key_values,
                use_cache=True,
            )

            AR_logits = outputs.logits  # [1, len, vocab_len]
            AR_logits = torch.cat([AR_logits[:, :1], AR_logits[:, :-1]], dim=1)
            AR_p_softmax = F.softmax(AR_logits, dim=-1)  # [1, len, vocab_len]
            AR_x0_p_softmax = torch.gather(AR_p_softmax, dim=-1, index=torch.unsqueeze(chosen_gen_blocks.reshape(1, -1), -1)).squeeze(-1)  # [1, len]
            AR_x0_p_softmax_blocks = AR_x0_p_softmax.reshape(-1, cur_slot_size)
            chosen_p_softmax_blocks[:, 1:] = AR_x0_p_softmax_blocks[:, 1:]

            prob_mask = chosen_p_softmax_blocks > token_threshold
            prob_mask[:, 0] = 1
            tag_blocks = torch.cumprod(prob_mask.int(), dim=-1)

            tag_tokens = torch.cumprod(prob_mask.int().reshape(1, -1), dim=-1)
            prefix_len = torch.sum(tag_tokens, dim=-1)
            flat_chosen_gen_blocks = chosen_gen_blocks.reshape(1, -1)
            confident_prefix_tokens = flat_chosen_gen_blocks[:, :prefix_len]
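
            # If the verified tokens form a contiguous prefix, commit whole verified slots
            # (and everything up to EOS if one appears) to cur_x and the KV cache, then drop
            # those slots from the pending set and move on to the next forward pass.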
            if prefix_len > 0:
                is_eos_in_prefix = (confident_prefix_tokens.squeeze(0) == eos_token_id)
                eos_found_flag = torch.any(is_eos_in_prefix)

                remain_indices = []

                indices_to_remove = set()

                if eos_found_flag:
                    first_eos_pos_tensor = torch.argmax(is_eos_in_prefix.int())

                    eos_block_pos = first_eos_pos_tensor // cur_slot_size + 1
                    eos_token_pos = first_eos_pos_tensor - (first_eos_pos_tensor // cur_slot_size) * cur_slot_size

                    eos_block = topk_indices_relative[eos_block_pos - 1].item()

                    remain_indices.extend(topk_indices_relative[:eos_block_pos].tolist())

                    topk_indices_relative = torch.tensor([], device=device)

                    eos_flag = True

                    indices_after_eos = list(range(eos_block, cur_gen_blocks_x.shape[1]))
                    indices_to_remove.update(indices_after_eos)

                elif (prefix_len // cur_slot_size) > 0:
                    num_prefix_blocks = prefix_len // cur_slot_size
                    remain_indices.extend(topk_indices_relative[:num_prefix_blocks].tolist())

                    topk_indices_relative = topk_indices_relative[num_prefix_blocks:]
                    tag_blocks = tag_blocks[num_prefix_blocks:]

                if len(remain_indices) > 0:

                    indices_to_remove.update(remain_indices)

                    token_indices = []

                    for i_idx, b_idx in enumerate(remain_indices):
                        start_index = b_idx * cur_slot_size

                        current_block_len = cur_slot_size
                        # If EOS exists and this is the last slot, then adjust the length.
                        if eos_found_flag and i_idx == len(remain_indices) - 1:
                            current_block_len = eos_token_pos + 1

                        end_index = start_index + current_block_len
                        block_range = torch.arange(start_index, end_index, dtype=torch.long, device=device)

                        token_indices.append(block_range)

                    full_token_indices = torch.cat(token_indices)

                    cur_x = torch.cat((cur_x, x0_gen[:, full_token_indices]), dim=1)
                    cur_pos = torch.cat((cur_pos, flat_gen_blocks_pos_ids[:, full_token_indices]), dim=1)

                    past_key_values = outputs.past_key_values
                    past_key_values.crop(cur_x.shape[1])

                    assert cur_x.shape[-1] == past_key_values[0][0].shape[-2]

                    prefix_block_tag = True

                    sum_TPF += cur_slot_size * len(remain_indices) / 2
                    forward_count += 1

            if prefix_block_tag:
                keep_mask = torch.ones(cur_gen_blocks_x.shape[1], dtype=torch.bool, device=device)
                keep_mask[list(indices_to_remove)] = False
                cur_gen_blocks_x = cur_gen_blocks_x[:, keep_mask, :]
                cur_gen_blocks_pos_ids = cur_gen_blocks_pos_ids[:, keep_mask, :]

                continue

            else:
                past_key_values = outputs.past_key_values
                past_key_values.crop(cur_x.shape[1])
                assert cur_x.shape[-1] == past_key_values[0][0].shape[-2]
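
            # Slots that were selected but not fully verified enter an iterative
            # draft-and-verify loop: still-unaccepted positions are re-masked and
            # re-predicted, then re-verified against token_threshold until every
            # chosen slot is completely filled.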
            indices_to_remove = set(topk_indices_relative.tolist())

            current_speculative_blocks = chosen_gen_blocks.clone()
            accepted_prefix_len = 0
            eos_found_in_loop = False

            if past_key_values is not None and counts_block > 1:
                past_key_values.batch_repeat_interleave(counts_block)

            for loop_iter in range(cur_slot_size):
                if not torch.any(tag_blocks == 0):
                    break

                input_tokens = current_speculative_blocks[:, accepted_prefix_len:]
                input_pos = chosen_position_ids[:, accepted_prefix_len:]

                current_tags = tag_blocks[:, accepted_prefix_len:]
                masked_input_tokens = torch.where(current_tags.bool(), input_tokens, mask_id)

                # Prediction
                draft_len = past_key_values[0][0].shape[2]
                draft_outputs = model(
                    input_ids=masked_input_tokens,
                    position_ids=input_pos,
                    past_key_values=past_key_values,
                    use_cache=False,
                )
                past_key_values.crop(draft_len)
                draft_logits = draft_outputs.logits
                proposed_tokens = torch.argmax(draft_logits, dim=-1)

                input_tokens = torch.where(current_tags.bool(), input_tokens, proposed_tokens)
                current_speculative_blocks[:, accepted_prefix_len:] = input_tokens

                # Verification
                verify_outputs = model(
                    input_ids=input_tokens,
                    position_ids=input_pos,
                    past_key_values=past_key_values,
                    use_cache=True,
                )
                verify_logits = verify_outputs.logits
                verify_logits = torch.cat([verify_logits[:, :1], verify_logits[:, :-1]], dim=1)

                verify_probs = F.softmax(verify_logits, dim=-1)
                gathered_probs = torch.gather(verify_probs, -1, input_tokens.unsqueeze(-1)).squeeze(-1)

                prob_mask = gathered_probs > token_threshold

                # Keep at least one token
                update_tag_blocks = F.pad(tag_blocks[:, accepted_prefix_len:], (1, 0), value=1)[:, :-1]

                prob_mask[update_tag_blocks == 1] = True

                new_tags = torch.cumprod(prob_mask.int(), dim=-1)
                tag_blocks[:, accepted_prefix_len:] = new_tags

                newly_verified_mask = (tag_blocks[:, accepted_prefix_len:] == 1)
                is_eos_in_new = (current_speculative_blocks[:, accepted_prefix_len:] == eos_token_id) & newly_verified_mask

                if torch.any(is_eos_in_new):
                    eos_found_in_loop = True
                    first_eos_block_idx = torch.where(torch.any(is_eos_in_new, dim=1))[0][0].item()

                    current_speculative_blocks = current_speculative_blocks[:first_eos_block_idx + 1]
                    tag_blocks = tag_blocks[:first_eos_block_idx + 1]
                    tag_blocks[first_eos_block_idx] = 1
                    chosen_position_ids = chosen_position_ids[:first_eos_block_idx + 1]
                    topk_indices_relative = topk_indices_relative[:first_eos_block_idx + 1]
                    if verify_outputs.past_key_values is not None:
                        verify_outputs.past_key_values.batch_select_minibatch(first_eos_block_idx + 1)

                current_tags = tag_blocks[:, accepted_prefix_len:]
                len_per_block = torch.sum(current_tags, dim=1)
                newly_accepted_len = torch.min(len_per_block).item()
                if newly_accepted_len > 0:
                    if torch.any(tag_blocks == 0):
                        accepted_prefix_len = accepted_prefix_len + newly_accepted_len - 1
                    else:
                        accepted_prefix_len = accepted_prefix_len + newly_accepted_len
                past_key_values = verify_outputs.past_key_values
                if past_key_values is not None:
                    past_key_values.crop(cur_x.shape[1] + accepted_prefix_len)

            sum_TPF += (cur_slot_size * counts_block) / (loop_iter * 2 + 2)
            forward_count += 1
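
            # Pull the KV entries of the newly accepted slot tokens out of the batched cache,
            # collapse the cache back to batch size 1, drop everything after the first EOS,
            # and splice the kept tokens (and their KV entries) onto the committed prefix.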
            ar_kv_cache = tuple(
                (
                    layer_past[0][:, :, -cur_slot_size:, :],  # key
                    layer_past[1][:, :, -cur_slot_size:, :]   # value
                )
                for layer_past in past_key_values
            )

            past_key_values.crop(cur_x.shape[1])
            past_key_values.batch_select_indices(torch.tensor([0]).to(device))

            eos_mask = (current_speculative_blocks == eos_token_id)  # (k * cur_slot_size)
            keep_mask = (torch.cumsum(eos_mask.flatten().int(), dim=-1) - eos_mask.flatten().int()) == 0
            kept_tokens = current_speculative_blocks.flatten()[keep_mask].reshape(batch_size, -1)
            kept_pos_ids = chosen_position_ids.flatten()[keep_mask].reshape(batch_size, -1)

            # update KV cache
            if kept_tokens.numel() > 0 and ar_kv_cache is not None:
                new_past = []
                for key, val in ar_kv_cache:
                    num_heads, head_dim = key.shape[1], key.shape[3]

                    flat_key = key.permute(1, 0, 2, 3).reshape(1, num_heads, -1, head_dim)
                    flat_val = val.permute(1, 0, 2, 3).reshape(1, num_heads, -1, head_dim)

                    kept_key = flat_key[:, :, keep_mask, :]
                    kept_val = flat_val[:, :, keep_mask, :]

                    new_past.append((kept_key, kept_val))

                kept_kv = tuple(new_past)

                past_key_values.full_update(kept_kv)

            cur_x = torch.cat((cur_x, kept_tokens), dim=1)
            cur_pos = torch.cat((cur_pos, kept_pos_ids), dim=1)

            assert cur_x.shape[-1] == past_key_values[0][0].shape[-2]

            if eos_found_in_loop:
                indices_after_eos = list(range(first_eos_block_idx, cur_gen_blocks_x.shape[1]))
                indices_to_remove.update(indices_after_eos)
                eos_flag = True

            keep_mask = torch.ones(cur_gen_blocks_x.shape[1], dtype=torch.bool, device=device)
            keep_mask[list(indices_to_remove)] = False
            cur_gen_blocks_x = cur_gen_blocks_x[:, keep_mask, :]
            cur_gen_blocks_pos_ids = cur_gen_blocks_pos_ids[:, keep_mask, :]

        if eos_flag:
            break

    # Restore positional order before returning (tokens may have been accepted out of order).
    _, re_mask_indices = torch.sort(cur_pos, dim=-1)
    x = torch.gather(cur_x, dim=-1, index=re_mask_indices)

    TPF = sum_TPF / forward_count

    return x, TPF


def main():
    device = 'cuda'

    model_path = "ReFusion"
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    prompt = "You are an expert Python programmer. Your task is to write a single Python function to solve the problem described below, and here is your task: Write a function to sum all amicable numbers from 1 to a specified number.\n\nDirectly after the '[BEGIN]' marker, you must write only the Python code for the function. Do not provide any explanations, comments, or introductory text. The function must include the 'def' line, its arguments, the function body, and a 'return' statement. Your code should pass these tests:\n\nassert amicable_numbers_sum(999)==504\nassert amicable_numbers_sum(9999)==31626\nassert amicable_numbers_sum(99)==0\n[BEGIN]\n"

    m = [{"role": "user", "content": prompt}]
    prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False, enable_thinking=True)

    print(prompt)

    input_ids = tokenizer(prompt)['input_ids']
    input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

    out, TPF = generate_refusion(model, tokenizer, input_ids, gen_length=512, temperature=0., mask_id=151670, slot_size=4, model_path=model_path, serial_num_blocks=32, slot_threshold=0.6, token_threshold=0.3)
    print(tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)[0])
    print("---------TPF:", TPF)


if __name__ == '__main__':
    main()
```
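
If you also want a rough wall-clock throughput figure alongside TPF, a minimal sketch (reusing the `model`, `tokenizer`, and `input_ids` set up in `main()` above; the arguments simply mirror the call there) could look like this:

```python
import time

start = time.time()
out, TPF = generate_refusion(model, tokenizer, input_ids, gen_length=512, temperature=0.,
                             mask_id=151670, slot_size=4, serial_num_blocks=32,
                             slot_threshold=0.6, token_threshold=0.3)
elapsed = time.time() - start

# Tokens actually produced (generation may stop early at EOS).
new_tokens = out.shape[1] - input_ids.shape[1]
print(f"{new_tokens} tokens in {elapsed:.2f}s "
      f"({new_tokens / elapsed:.1f} tok/s, average tokens per forward = {TPF:.2f})")
```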