Update app.py
Browse files
app.py
CHANGED
|
@@ -5,6 +5,7 @@ import os
|
|
| 5 |
import httpx
|
| 6 |
|
| 7 |
logging.set_verbosity_error()
|
|
|
|
| 8 |
|
| 9 |
def download_argos_model(from_code, to_code):
|
| 10 |
import argostranslate.package
|
|
@@ -53,7 +54,8 @@ if model_name == 'Helsinki-NLP':
|
|
| 53 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
| 54 |
if model_name.startswith('t5'):
|
| 55 |
tokenizer = T5Tokenizer.from_pretrained(model_name)
|
| 56 |
-
model = T5ForConditionalGeneration.from_pretrained(model_name)
|
|
|
|
| 57 |
|
| 58 |
st.write("Selected language combination:", sselected_language, " - ", tselected_language, "Selected model:", model_name)
|
| 59 |
submit_button = st.button("Translate")
|
|
@@ -77,7 +79,7 @@ if submit_button:
|
|
| 77 |
elif model_name.startswith('t5'):
|
| 78 |
prompt = f'translate {sselected_language} to {tselected_language}: {input_text}'
|
| 79 |
print(prompt)
|
| 80 |
-
input_ids = tokenizer.encode(prompt, return_tensors='pt')
|
| 81 |
# Perform translation
|
| 82 |
output_ids = model.generate(input_ids)
|
| 83 |
# Decode the translated text
|
|
@@ -104,8 +106,6 @@ if submit_button:
|
|
| 104 |
translated_text = f"No Argos model for {sselected_language} to {tselected_language}. Try other model or languages combination!"
|
| 105 |
except Exception as error:
|
| 106 |
translated_text = error
|
| 107 |
-
# download_argos_model(sl, tl)
|
| 108 |
-
# translated_text = argostranslate.translate.translate(input_text, sl, tl)
|
| 109 |
# Display the translated text
|
| 110 |
print(translated_text)
|
| 111 |
st.write(f"Translated text from {sselected_language} to {tselected_language} using {model_name}:")
|
|
|
|
| 5 |
import httpx
|
| 6 |
|
| 7 |
logging.set_verbosity_error()
|
| 8 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 9 |
|
| 10 |
def download_argos_model(from_code, to_code):
|
| 11 |
import argostranslate.package
|
|
|
|
| 54 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
| 55 |
if model_name.startswith('t5'):
|
| 56 |
tokenizer = T5Tokenizer.from_pretrained(model_name)
|
| 57 |
+
model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
|
| 58 |
+
|
| 59 |
|
| 60 |
st.write("Selected language combination:", sselected_language, " - ", tselected_language, "Selected model:", model_name)
|
| 61 |
submit_button = st.button("Translate")
|
|
|
|
| 79 |
elif model_name.startswith('t5'):
|
| 80 |
prompt = f'translate {sselected_language} to {tselected_language}: {input_text}'
|
| 81 |
print(prompt)
|
| 82 |
+
input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
|
| 83 |
# Perform translation
|
| 84 |
output_ids = model.generate(input_ids)
|
| 85 |
# Decode the translated text
|
|
|
|
| 106 |
translated_text = f"No Argos model for {sselected_language} to {tselected_language}. Try other model or languages combination!"
|
| 107 |
except Exception as error:
|
| 108 |
translated_text = error
|
|
|
|
|
|
|
| 109 |
# Display the translated text
|
| 110 |
print(translated_text)
|
| 111 |
st.write(f"Translated text from {sselected_language} to {tselected_language} using {model_name}:")
|