Spaces:
Runtime error
Runtime error
Commit
·
d38f5f1
1
Parent(s):
422252e
Fix a bug in generate_text
Browse files
app.py
CHANGED
|
@@ -34,12 +34,10 @@ def generate_text(prompt: str,
|
|
| 34 |
temperature: float = 0.5,
|
| 35 |
top_p: float = 0.95,
|
| 36 |
top_k: int = 50) -> str:
|
| 37 |
-
|
| 38 |
# Encode the prompt
|
| 39 |
inputs = tokenizer([prompt],
|
| 40 |
return_tensors='pt',
|
| 41 |
add_special_tokens=False).to(DEVICE)
|
| 42 |
-
|
| 43 |
# Prepare arguments for generation
|
| 44 |
input_length = inputs["input_ids"].shape[-1]
|
| 45 |
max_new_tokens = min(max_new_tokens, WINDOW_SIZE - input_length)
|
|
@@ -56,8 +54,8 @@ def generate_text(prompt: str,
|
|
| 56 |
skip_prompt=True,
|
| 57 |
skip_special_tokens=True)
|
| 58 |
generation_kwargs = dict(
|
| 59 |
-
inputs
|
| 60 |
-
streamer=
|
| 61 |
max_new_tokens=max_new_tokens,
|
| 62 |
do_sample=True,
|
| 63 |
top_p=top_p,
|
|
@@ -65,12 +63,10 @@ def generate_text(prompt: str,
|
|
| 65 |
temperature=temperature,
|
| 66 |
num_beams=1,
|
| 67 |
)
|
| 68 |
-
|
| 69 |
# Generate text
|
| 70 |
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
| 71 |
thread.start()
|
| 72 |
-
|
| 73 |
-
generated_text = ""
|
| 74 |
for new_text in streamer:
|
| 75 |
generated_text += new_text
|
| 76 |
return generated_text
|
|
|
|
| 34 |
temperature: float = 0.5,
|
| 35 |
top_p: float = 0.95,
|
| 36 |
top_k: int = 50) -> str:
|
|
|
|
| 37 |
# Encode the prompt
|
| 38 |
inputs = tokenizer([prompt],
|
| 39 |
return_tensors='pt',
|
| 40 |
add_special_tokens=False).to(DEVICE)
|
|
|
|
| 41 |
# Prepare arguments for generation
|
| 42 |
input_length = inputs["input_ids"].shape[-1]
|
| 43 |
max_new_tokens = min(max_new_tokens, WINDOW_SIZE - input_length)
|
|
|
|
| 54 |
skip_prompt=True,
|
| 55 |
skip_special_tokens=True)
|
| 56 |
generation_kwargs = dict(
|
| 57 |
+
**inputs,
|
| 58 |
+
streamer=streamer,
|
| 59 |
max_new_tokens=max_new_tokens,
|
| 60 |
do_sample=True,
|
| 61 |
top_p=top_p,
|
|
|
|
| 63 |
temperature=temperature,
|
| 64 |
num_beams=1,
|
| 65 |
)
|
|
|
|
| 66 |
# Generate text
|
| 67 |
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
| 68 |
thread.start()
|
| 69 |
+
generated_text = prompt
|
|
|
|
| 70 |
for new_text in streamer:
|
| 71 |
generated_text += new_text
|
| 72 |
return generated_text
|