Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,128 +4,110 @@ from transformers import pipeline
|
|
| 4 |
import os
|
| 5 |
|
| 6 |
# --- App Configuration ---
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
Enter a topic or a
|
| 10 |
-
|
| 11 |
-
**
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
"""
|
| 13 |
|
| 14 |
-
#
|
| 15 |
examples = [
|
| 16 |
-
["
|
| 17 |
-
["
|
| 18 |
-
["
|
| 19 |
-
["
|
| 20 |
-
["
|
| 21 |
-
["Alleviating stress"],
|
| 22 |
-
["Helping breathing, satisfaction"],
|
| 23 |
-
["Relieve Stress, Build Support"],
|
| 24 |
-
["The Relaxation Response"],
|
| 25 |
-
["Taking Deep Breaths"],
|
| 26 |
-
["Delete Not Helpful Thoughts"],
|
| 27 |
-
["Strengthen Helpful Thoughts"],
|
| 28 |
-
["Reprogram Pain and Stress Reactions"],
|
| 29 |
-
["How to Sleep Better and Find Joy"],
|
| 30 |
-
["Yoga for deep sleep"],
|
| 31 |
-
["Being a Happier and Healthier Person"],
|
| 32 |
-
["Relieve chronic pain by"],
|
| 33 |
-
["Use Mindfulness to Affect Well Being"],
|
| 34 |
-
["Build and Boost Mental Strength"],
|
| 35 |
-
["Spending Time Outdoors"],
|
| 36 |
-
["Daily Routine Tasks"],
|
| 37 |
-
["Eating and Drinking - Find Healthy Nutrition Habits"],
|
| 38 |
-
["Drinking - Find Reasons and Cut Back or Quit Entirely"],
|
| 39 |
-
["Feel better each day when you awake by"],
|
| 40 |
-
["Feel better physically by"],
|
| 41 |
-
["Practicing mindfulness each day"],
|
| 42 |
-
["Be happier by"],
|
| 43 |
-
["Meditation can improve health by"],
|
| 44 |
-
["Spending time outdoors helps to"],
|
| 45 |
-
["Stress is relieved by quieting your mind, getting exercise and time with nature"],
|
| 46 |
-
["Break the cycle of stress and anxiety"],
|
| 47 |
-
["Feel calm in stressful situations"],
|
| 48 |
-
["Deal with work pressure by"],
|
| 49 |
-
["Learn to reduce feelings of being overwhelmed"]
|
| 50 |
]
|
| 51 |
|
| 52 |
# --- Model Initialization ---
|
| 53 |
-
#
|
| 54 |
-
#
|
| 55 |
-
# Install dependencies: pip install gradio transformers torch accelerate
|
| 56 |
try:
|
| 57 |
print("Initializing models... This may take several minutes.")
|
| 58 |
|
| 59 |
-
#
|
| 60 |
-
#
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
| 64 |
|
| 65 |
-
generator2 = pipeline("text-generation", model="
|
| 66 |
-
print("
|
| 67 |
|
| 68 |
-
generator3 = pipeline("text-generation", model="
|
| 69 |
-
print("
|
| 70 |
|
| 71 |
-
print("All models loaded successfully!
|
| 72 |
|
| 73 |
except Exception as e:
|
| 74 |
-
print(f"Error loading models
|
| 75 |
-
print("
|
| 76 |
-
|
|
|
|
| 77 |
def failed_generator(prompt, **kwargs):
|
| 78 |
-
return [{'generated_text': "
|
| 79 |
generator1 = generator2 = generator3 = failed_generator
|
| 80 |
|
| 81 |
|
| 82 |
# --- App Logic ---
|
| 83 |
-
def
|
| 84 |
-
"""Generates text from the three loaded models."""
|
| 85 |
-
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
|
|
|
| 91 |
|
| 92 |
return out1, out2, out3
|
| 93 |
|
| 94 |
# --- Gradio Interface ---
|
| 95 |
-
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 96 |
-
gr.Markdown(f"<h1 style='text-align: center;'>{
|
| 97 |
-
gr.Markdown(
|
| 98 |
|
| 99 |
with gr.Row():
|
| 100 |
with gr.Column(scale=1):
|
| 101 |
input_area = gr.TextArea(
|
| 102 |
-
lines=
|
| 103 |
-
label="Your
|
| 104 |
-
placeholder="e.g., '
|
| 105 |
)
|
| 106 |
-
generate_button = gr.Button("
|
| 107 |
|
| 108 |
with gr.Column(scale=2):
|
| 109 |
with gr.Tabs():
|
| 110 |
-
with gr.TabItem("
|
| 111 |
-
gen1_output = gr.TextArea(label="
|
| 112 |
-
with gr.TabItem("
|
| 113 |
-
gen2_output = gr.TextArea(label="
|
| 114 |
-
with gr.TabItem("
|
| 115 |
-
gen3_output = gr.TextArea(label="
|
| 116 |
|
| 117 |
gr.Examples(
|
| 118 |
examples=examples,
|
| 119 |
inputs=input_area,
|
| 120 |
-
label="Example
|
| 121 |
)
|
| 122 |
|
| 123 |
generate_button.click(
|
| 124 |
-
fn=
|
| 125 |
inputs=input_area,
|
| 126 |
outputs=[gen1_output, gen2_output, gen3_output],
|
| 127 |
api_name="generate"
|
| 128 |
)
|
| 129 |
|
| 130 |
if __name__ == "__main__":
|
| 131 |
-
demo.launch()
|
|
|
|
| 4 |
import os
|
| 5 |
|
| 6 |
# --- App Configuration ---
|
| 7 |
+
TITLE = "✍️ AI Story Weaver"
|
| 8 |
+
DESCRIPTION = """
|
| 9 |
+
Enter a prompt, a topic, or the beginning of a story, and get three different continuations from powerful open-source AI models.
|
| 10 |
+
This app uses:
|
| 11 |
+
- **Mistral-7B-Instruct-v0.2**
|
| 12 |
+
- **Google's Gemma-7B-IT**
|
| 13 |
+
- **Meta's Llama-3-8B-Instruct**
|
| 14 |
+
|
| 15 |
+
**⚠️ Hardware Warning:** These are very large models. Loading them requires a powerful GPU with significant VRAM (ideally > 24GB).
|
| 16 |
+
The initial loading process may take several minutes. You will also need to install the `accelerate` library: `pip install accelerate`
|
| 17 |
"""
|
| 18 |
|
| 19 |
+
# --- Example Prompts for Storytelling ---
|
| 20 |
examples = [
|
| 21 |
+
["The old lighthouse keeper stared into the storm. He'd seen many tempests, but this one was different. This one had eyes..."],
|
| 22 |
+
["In a city powered by dreams, a young inventor creates a machine that can record them. His first recording reveals a nightmare that doesn't belong to him."],
|
| 23 |
+
["The knight adjusted his helmet, the dragon's roar echoing in the valley. He was ready for the fight, but not for what the dragon said when it finally spoke."],
|
| 24 |
+
["She found the old leather-bound journal in her grandfather's attic. The first entry read: 'To relieve stress, I walk in the woods. But today, the woods walked with me.'"],
|
| 25 |
+
["The meditation app promised to help her 'delete unhelpful thoughts.' She tapped the button, and to her horror, the memory of her own name began to fade..."]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
]
|
| 27 |
|
| 28 |
# --- Model Initialization ---
|
| 29 |
+
# This section loads the models. It requires significant hardware resources.
|
| 30 |
+
# `device_map="auto"` and `torch_dtype="auto"` help manage resources by using available GPUs and half-precision.
|
|
|
|
| 31 |
try:
|
| 32 |
print("Initializing models... This may take several minutes.")
|
| 33 |
|
| 34 |
+
# NOTE: For Llama-3, you may need to log in to Hugging Face and accept the license agreement.
|
| 35 |
+
# from huggingface_hub import login
|
| 36 |
+
# login("YOUR_HF_TOKEN")
|
| 37 |
+
|
| 38 |
+
generator1 = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.2", torch_dtype="auto", device_map="auto")
|
| 39 |
+
print("✅ Mistral-7B loaded.")
|
| 40 |
|
| 41 |
+
generator2 = pipeline("text-generation", model="google/gemma-7b-it", torch_dtype="auto", device_map="auto")
|
| 42 |
+
print("✅ Gemma-7B loaded.")
|
| 43 |
|
| 44 |
+
generator3 = pipeline("text-generation", model="meta-llama/Llama-3-8B-Instruct", torch_dtype="auto", device_map="auto")
|
| 45 |
+
print("✅ Llama-3-8B loaded.")
|
| 46 |
|
| 47 |
+
print("All models loaded successfully! 🎉")
|
| 48 |
|
| 49 |
except Exception as e:
|
| 50 |
+
print(f"--- 🚨 Error loading models ---")
|
| 51 |
+
print(f"Error: {e}")
|
| 52 |
+
print("Please ensure you have 'torch' and 'accelerate' installed, have sufficient VRAM, and are logged into Hugging Face if required.")
|
| 53 |
+
# Create a dummy function if models fail, so the app can still launch with an error message.
|
| 54 |
def failed_generator(prompt, **kwargs):
|
| 55 |
+
return [{'generated_text': "A model failed to load. Please check the console for errors. You may need more VRAM or need to accept model license terms on Hugging Face."}]
|
| 56 |
generator1 = generator2 = generator3 = failed_generator
|
| 57 |
|
| 58 |
|
| 59 |
# --- App Logic ---
|
| 60 |
+
def generate_stories(prompt: str) -> tuple[str, str, str]:
|
| 61 |
+
"""Generates text from the three loaded models based on the user's prompt."""
|
| 62 |
+
if not prompt:
|
| 63 |
+
return "Please enter a prompt to start.", "", ""
|
| 64 |
+
|
| 65 |
+
# We use 'max_new_tokens' to control the length of the generated story.
|
| 66 |
+
# Increased to 200 for more substantial story continuations.
|
| 67 |
+
params = {"max_new_tokens": 200, "do_sample": True, "temperature": 0.7, "top_p": 0.95}
|
| 68 |
|
| 69 |
+
# Generate from all three models
|
| 70 |
+
out1 = generator1(prompt, **params)[0]['generated_text']
|
| 71 |
+
out2 = generator2(prompt, **params)[0]['generated_text']
|
| 72 |
+
out3 = generator3(prompt, **params)[0]['generated_text']
|
| 73 |
|
| 74 |
return out1, out2, out3
|
| 75 |
|
| 76 |
# --- Gradio Interface ---
|
| 77 |
+
with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 95% !important;}") as demo:
|
| 78 |
+
gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
|
| 79 |
+
gr.Markdown(DESCRIPTION)
|
| 80 |
|
| 81 |
with gr.Row():
|
| 82 |
with gr.Column(scale=1):
|
| 83 |
input_area = gr.TextArea(
|
| 84 |
+
lines=5,
|
| 85 |
+
label="Your Story Prompt 👇",
|
| 86 |
+
placeholder="e.g., 'The last dragon on Earth lived not in a cave, but in a library...'"
|
| 87 |
)
|
| 88 |
+
generate_button = gr.Button("Weave a Story ✨", variant="primary")
|
| 89 |
|
| 90 |
with gr.Column(scale=2):
|
| 91 |
with gr.Tabs():
|
| 92 |
+
with gr.TabItem("Mistral-7B"):
|
| 93 |
+
gen1_output = gr.TextArea(label="Mistral's Tale", interactive=False, lines=12)
|
| 94 |
+
with gr.TabItem("Gemma-7B"):
|
| 95 |
+
gen2_output = gr.TextArea(label="Gemma's Chronicle", interactive=False, lines=12)
|
| 96 |
+
with gr.TabItem("Llama-3-8B"):
|
| 97 |
+
gen3_output = gr.TextArea(label="Llama's Legend", interactive=False, lines=12)
|
| 98 |
|
| 99 |
gr.Examples(
|
| 100 |
examples=examples,
|
| 101 |
inputs=input_area,
|
| 102 |
+
label="Example Story Starters (Click to use)"
|
| 103 |
)
|
| 104 |
|
| 105 |
generate_button.click(
|
| 106 |
+
fn=generate_stories,
|
| 107 |
inputs=input_area,
|
| 108 |
outputs=[gen1_output, gen2_output, gen3_output],
|
| 109 |
api_name="generate"
|
| 110 |
)
|
| 111 |
|
| 112 |
if __name__ == "__main__":
|
| 113 |
+
demo.launch()
|