# Scraped Hugging Face file-view header (not part of the program):
# TiberiuCristianLeon's picture
# Update app.py
# 059d3ee verified
# raw
# history blame
# 33.2 kB
import streamlit as st
import polars as pl
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, logging, AutoModelForCausalLM
import torch
import os
import httpx
# Torch device shared by module-level code: GPU when available, otherwise CPU.
_device_kind = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_kind)

# Languages offered in the UI select boxes.
options = ["German", "Romanian", "English", "French", "Spanish", "Italian"]
favourite_langs = {"German": "de", "Romanian": "ro", "English": "en", "-----": "-----"}

# ISO language table. The first row is skipped and only rows with a non-empty
# "ISO639-1" column are kept; each row is read positionally as
# (name, code, code, code, ...), e.g. ('Romanian', 'ro', 'rum', 'ron').
df = pl.read_parquet("isolanguages.parquet")
non_empty_isos = df.slice(1).filter(pl.col("ISO639-1") != "").rows()

all_langs = {}      # {'Romanian': ('ro', 'rum', 'ron')}
iso1_to_name = {}   # {'ro': 'Romanian', 'de': 'German'}
langs = {}          # {'Romanian': 'ro', 'German': 'de'}
for _row in non_empty_isos:
    _name, _iso1 = _row[0], _row[1]
    all_langs[_name] = (_iso1, _row[2], _row[3])
    iso1_to_name[_iso1] = _name
    langs[_name] = _iso1

# Models shown in the UI model selector.
models = [
    "Helsinki-NLP", "Argos", "Google", "t5-base", "t5-small", "t5-large", "Unbabel/Tower-Plus-2B",
    "Unbabel/TowerInstruct-Mistral-7B-v0.2", "winninghealth/WiNGPT-Babel-2",
]

# Full catalogue of model identifiers (not referenced by the code visible in this file).
allmodels = [
    "Helsinki-NLP",
    "Helsinki-NLP/opus-mt-tc-bible-big-mul-mul",
    "Helsinki-NLP/opus-mt-tc-bible-big-mul-deu_eng_nld",
    "Helsinki-NLP/opus-mt-tc-bible-big-mul-deu_eng_fra_por_spa",
    "Helsinki-NLP/opus-mt-tc-bible-big-deu_eng_fra_por_spa-mul",
    "Helsinki-NLP/opus-mt-tc-bible-big-roa-deu_eng_fra_por_spa",
    "Helsinki-NLP/opus-mt-tc-bible-big-deu_eng_fra_por_spa-roa",
    "facebook/nllb-200-distilled-600M",
    "facebook/nllb-200-distilled-1.3B",
    "facebook/nllb-200-1.3B",
    "facebook/nllb-200-3.3B",
    "facebook/mbart-large-50-many-to-many-mmt",
    "facebook/mbart-large-50-one-to-many-mmt",
    "facebook/mbart-large-50-many-to-one-mmt",
    "facebook/m2m100_418M",
    "facebook/m2m100_1.2B",
    "Lego-MT/Lego-MT",
    "bigscience/mt0-small",
    "bigscience/mt0-base",
    "bigscience/mt0-large",
    "bigscience/mt0-xl",
    "bigscience/bloomz-560m",
    "bigscience/bloomz-1b1",
    "bigscience/bloomz-1b7",
    "bigscience/bloomz-3b",
    "t5-small",
    "t5-base",
    "t5-large",
    "google/flan-t5-small",
    "google/flan-t5-base",
    "google/flan-t5-large",
    "google/flan-t5-xl",
    "google/madlad400-3b-mt",
    "jbochi/madlad400-3b-mt",
    "Argos",
    "Google",
    "HuggingFaceTB/SmolLM3-3B",
    "winninghealth/WiNGPT-Babel-2",
    "utter-project/EuroLLM-1.7B",
    "utter-project/EuroLLM-1.7B-Instruct",
    "Unbabel/Tower-Plus-2B",
    "Unbabel/TowerInstruct-7B-v0.2",
    "Unbabel/TowerInstruct-Mistral-7B-v0.2",
    "openGPT-X/Teuken-7B-instruct-commercial-v0.4",
    "openGPT-X/Teuken-7B-instruct-v0.6",
]
class Translators:
    """One translation request plus one method per supported back-end.

    An instance stores the model identifier, the source/target language (an ISO
    code or a language name, depending on which back-end the caller intends to
    use) and the text to translate. Each public method loads its model, runs
    the translation and returns the translated text; the two Helsinki methods
    and ``bergamot`` return a ``(text, message)`` tuple instead.
    """

    def __init__(self, model_name: str, sl: str, tl: str, input_text: str):
        self.model_name = model_name
        self.sl, self.tl = sl, tl
        self.input_text = input_text
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def google(self):
        """Translate through the HTTP endpoint stored in the ``GCLIENT`` environment variable.

        The query string is URL-encoded so that characters such as ``&``, ``#``
        or ``+`` in the input text cannot corrupt the request.

        Raises:
            KeyError: if ``GCLIENT`` is not set.
        """
        from urllib.parse import urlencode

        query = urlencode({"sl": self.sl, "tl": self.tl, "q": self.input_text})
        # GCLIENT is used as a prefix the query is appended to, as in the original code.
        response = httpx.get(os.environ['GCLIENT'] + query)
        # NOTE(review): assumes payload[0] is a list of segments whose first item is the
        # translated text (the shape implied by the former `[0][0][0]` indexing). All
        # segments are joined so multi-sentence input is not cut after the first one.
        segments = response.json()[0]
        return ''.join(segment[0] for segment in segments if segment and segment[0])

    @classmethod
    def download_argos_model(cls, from_code, to_code):
        """Download and install the Argos package for ``from_code`` -> ``to_code``.

        Raises:
            StopIteration: if no available package matches the language pair
                (``argos`` relies on this to report the missing pair).
        """
        import argostranslate.package

        print('Downloading model', from_code, to_code)
        argostranslate.package.update_package_index()
        available_packages = argostranslate.package.get_available_packages()
        package_to_install = next(
            pkg for pkg in available_packages
            if pkg.from_code == from_code and pkg.to_code == to_code
        )
        argostranslate.package.install_from_path(package_to_install.download())

    def argos(self):
        """Translate with Argos Translate; return the translation or an explanatory message."""
        import argostranslate.translate, argostranslate.package

        try:
            Translators.download_argos_model(self.sl, self.tl)  # Download model
            translated_text = argostranslate.translate.translate(self.input_text, self.sl, self.tl)
        except StopIteration:
            # No package for this pair: list what is available instead.
            packages_info = ', '.join(
                f"{pkg.from_name} ({pkg.from_code}) -> {pkg.to_name} ({pkg.to_code})"
                for pkg in argostranslate.package.get_available_packages()
            )
            translated_text = f"No Argos model for {self.sl} to {self.tl}. Try other model or languages combination from the available Argos models: {packages_info}."
        except Exception as error:
            # Return text rather than the exception object so callers always get a string.
            translated_text = str(error)
        return translated_text

    def HelsinkiNLP_mulroa(self):
        """Translate with a multilingual Helsinki model that needs a ``>>xxx<<`` target prefix.

        Returns:
            ``(translation, message)`` on success, ``(error text, error details)`` on failure.
        """
        try:
            pipe = pipeline("translation", model=self.model_name, device=self.device)
            iso1to3 = {iso[1]: iso[3] for iso in non_empty_isos}  # {'ro': 'ron'}
            iso3tl = iso1to3.get(self.tl)  # 'deu', 'ron', 'eng', 'fra'
            translation = pipe(f'>>{iso3tl}<< {self.input_text}')
            return translation[0]['translation_text'], f'Translated from {iso1_to_name[self.sl]} to {iso1_to_name[self.tl]} with {self.model_name}.'
        except Exception as error:
            return f"Error translating with model: {self.model_name}! Try other available language combination.", str(error)

    def HelsinkiNLP(self):
        """Translate with a bilingual Helsinki model, falling back to broader models.

        Tries ``opus-mt-{sl}-{tl}``, then ``opus-tatoeba-{sl}-{tl}``; when neither
        can be loaded, falls back to the multilingual ``mul-mul`` model.

        Returns:
            ``(translation, message)`` on success, ``(error text, error details)`` on failure.
        """
        candidates = (
            f"Helsinki-NLP/opus-mt-{self.sl}-{self.tl}",       # Standard bilingual model
            f"Helsinki-NLP/opus-tatoeba-{self.sl}-{self.tl}",  # Tatoeba models
        )
        for model_name in candidates:
            try:
                pipe = pipeline("translation", model=model_name, device=self.device)
                translation = pipe(self.input_text)
                return translation[0]['translation_text'], f'Translated from {iso1_to_name[self.sl]} to {iso1_to_name[self.tl]} with {model_name}.'
            except EnvironmentError:
                continue  # Model could not be loaded: try the next candidate
            except KeyError as error:
                return f"Error: Translation direction {self.sl} to {self.tl} is not supported by Helsinki Translation Models", str(error)
        self.model_name = "Helsinki-NLP/opus-mt-tc-bible-big-mul-mul"  # Last resort: try multi to multi
        return self.HelsinkiNLP_mulroa()

    def LLaMAX(self):
        """Translate with the LLaMAX3-8B chat model and return the assistant's reply."""
        pipe = pipeline("text-generation", model="LLaMAX/LLaMAX3-8B")
        messages = [
            # The target language was previously written as the source language twice.
            {"role": "user", "content": f"Translate the following text from {self.sl} to {self.tl}: {self.input_text}"},
        ]
        generated = pipe(messages)[0]["generated_text"]
        # With chat-style input the pipeline returns the whole conversation as a list of
        # messages; the last one is the newly generated reply.
        if isinstance(generated, list):
            return generated[-1]["content"]
        return generated

    def LegoMT(self):
        """Translate with an M2M100-architecture Lego-MT checkpoint."""
        from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

        model = M2M100ForConditionalGeneration.from_pretrained(self.model_name)  # "Lego-MT/Lego-MT"
        tokenizer = M2M100Tokenizer.from_pretrained(self.model_name)
        tokenizer.src_lang = self.sl
        encoded = tokenizer(self.input_text, return_tensors="pt")
        generated_tokens = model.generate(**encoded, forced_bos_token_id=tokenizer.get_lang_id(self.tl))
        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    def madlad(self):
        """Translate with a MADLAD-400 T5 checkpoint using the ``<2xx>`` target-language prefix."""
        model = T5ForConditionalGeneration.from_pretrained(self.model_name, device_map="auto")
        tokenizer = T5Tokenizer.from_pretrained(self.model_name)
        text = f"<2{self.tl}> {self.input_text}"
        translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=self.sl, tgt_lang=self.tl)
        translated_text = translator(text, max_length=512)
        return translated_text[0]['translation_text']

    def smollm(self):
        """Translate by prompting a causal LM and keeping the text after ``Translation:``."""
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        model = AutoModelForCausalLM.from_pretrained(self.model_name)
        prompt = (
            f"Translate the following {self.sl} text to {self.tl}, generating only the translated text "
            "and maintaining the original meaning and tone:\n"
            f"{self.input_text}\n"
            "Translation:"
        )
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(
            inputs.input_ids,
            max_length=len(inputs.input_ids[0]) + 150,  # prompt length plus room for the answer
            temperature=0.3,
            do_sample=True,
        )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(response)
        return response.split("Translation:")[-1].strip()

    def flan(self):
        """Translate with a FLAN-T5 checkpoint using a ``translate X to Y:`` prompt."""
        tokenizer = T5Tokenizer.from_pretrained(self.model_name, legacy=False)
        model = T5ForConditionalGeneration.from_pretrained(self.model_name)
        prompt = f"translate {self.sl} to {self.tl}: {self.input_text}"
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        # Explicit budget: the library's default generation length truncates longer translations.
        outputs = model.generate(input_ids, max_new_tokens=512)
        return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    def tfive(self):
        """Translate with a T5 checkpoint using a ``translate X to Y:`` prompt."""
        tokenizer = T5Tokenizer.from_pretrained(self.model_name)
        model = T5ForConditionalGeneration.from_pretrained(self.model_name, device_map="auto")
        prompt = f"translate {self.sl} to {self.tl}: {self.input_text}"
        # device_map="auto" may place the model on an accelerator; inputs must follow it.
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
        output_ids = model.generate(input_ids, max_length=512)
        return tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()

    def mbart_many_to_many(self):
        """Translate between any two mBART-50 languages."""
        from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

        model = MBartForConditionalGeneration.from_pretrained(self.model_name)
        tokenizer = MBart50TokenizerFast.from_pretrained(self.model_name)
        # NOTE(review): `languagecodes` is not imported anywhere in the visible file;
        # unless it is provided elsewhere this lookup raises NameError.
        tokenizer.src_lang = languagecodes.mbart_large_languages[self.sl]
        encoded = tokenizer(self.input_text, return_tensors="pt")
        generated_tokens = model.generate(
            **encoded,
            forced_bos_token_id=tokenizer.lang_code_to_id[languagecodes.mbart_large_languages[self.tl]]
        )
        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    def mbart_one_to_many(self):
        """Translate from English (``en_XX``) into an mBART-50 target language."""
        from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

        model = MBartForConditionalGeneration.from_pretrained(self.model_name)
        tokenizer = MBart50TokenizerFast.from_pretrained(self.model_name, src_lang="en_XX")
        model_inputs = tokenizer(self.input_text, return_tensors="pt")
        # NOTE(review): `languagecodes` is not imported anywhere in the visible file.
        langid = languagecodes.mbart_large_languages[self.tl]
        generated_tokens = model.generate(
            **model_inputs,
            forced_bos_token_id=tokenizer.lang_code_to_id[langid]
        )
        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    def mbart_many_to_one(self):
        """Translate from an mBART-50 source language with the many-to-one checkpoint."""
        from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

        model = MBartForConditionalGeneration.from_pretrained(self.model_name)
        tokenizer = MBart50TokenizerFast.from_pretrained(self.model_name)
        # NOTE(review): `languagecodes` is not imported anywhere in the visible file.
        tokenizer.src_lang = languagecodes.mbart_large_languages[self.sl]
        encoded = tokenizer(self.input_text, return_tensors="pt")
        generated_tokens = model.generate(**encoded)
        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    def mtom(self):
        """Translate with an M2M100 checkpoint."""
        from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

        model = M2M100ForConditionalGeneration.from_pretrained(self.model_name)
        tokenizer = M2M100Tokenizer.from_pretrained(self.model_name)
        tokenizer.src_lang = self.sl
        encoded = tokenizer(self.input_text, return_tensors="pt")
        generated_tokens = model.generate(**encoded, forced_bos_token_id=tokenizer.get_lang_id(self.tl))
        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    def bigscience(self):
        """Translate with an mT0 seq2seq checkpoint using a ``Translate to X:`` prompt."""
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
        # The prompt always ends the source text with a full stop.
        text = self.input_text if self.input_text.endswith('.') else f'{self.input_text}.'
        inputs = tokenizer.encode(f"Translate to {self.tl}: {text}", return_tensors="pt")
        # Explicit budget: the library's default generation length truncates longer translations.
        outputs = model.generate(inputs, max_new_tokens=512)
        translation = tokenizer.decode(outputs[0])
        translation = translation.replace('<pad> ', '').replace('</s>', '')
        return translation

    def bloomz(self):
        """Translate with a BLOOMZ causal checkpoint using a ``Translate to X:`` prompt."""
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        model = AutoModelForCausalLM.from_pretrained(self.model_name)
        text = self.input_text if self.input_text.endswith('.') else f'{self.input_text}.'
        inputs = tokenizer.encode(f"Translate to {self.tl}: {text}", return_tensors="pt")
        # A causal model's output starts with the prompt tokens and the default length
        # limit counts them too, so give the answer its own budget ...
        outputs = model.generate(inputs, max_new_tokens=512)
        # ... and decode only what was generated after the prompt, so the prompt is
        # not echoed back as part of the translation.
        new_tokens = outputs[0][inputs.shape[-1]:]
        translation = tokenizer.decode(new_tokens, skip_special_tokens=True)
        translation = translation.split('Translation:')[-1].strip() if 'Translation:' in translation else translation.strip()
        return translation

    def nllb(self):
        """Translate with an NLLB-200 checkpoint; ``sl``/``tl`` must be NLLB language codes."""
        tokenizer = AutoTokenizer.from_pretrained(self.model_name, src_lang=self.sl)
        model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
        translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=self.sl, tgt_lang=self.tl)
        translated_text = translator(self.input_text, max_length=512)
        return translated_text[0]['translation_text']

    def wingpt(self):
        """Translate with WiNGPT-Babel via its chat template; return the last output line."""
        model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            dtype="auto",
            device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        messages = [
            {"role": "system", "content": f"Translate this to {self.tl} language"},
            {"role": "user", "content": self.input_text}
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=512,
            temperature=0.1
        )
        # Drop the prompt tokens from each returned sequence.
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        print(decoded)
        output = decoded[0]
        return output.split('\n')[-1].strip() if '\n' in output else output.strip()

    def eurollm(self):
        """Translate with the EuroLLM base model using a ``{sl}: text {tl}:`` prompt."""
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        model = AutoModelForCausalLM.from_pretrained(self.model_name)
        prompt = f"{self.sl}: {self.input_text} {self.tl}:"
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(**inputs, max_new_tokens=512)
        output = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(output)
        # Keep what follows the last "{tl}:" marker; without a marker the whole output is kept.
        return output.rsplit(f'{self.tl}:')[-1].strip()

    def eurollm_instruct(self):
        """Translate with the EuroLLM instruct model using a hand-written chat prompt."""
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        model = AutoModelForCausalLM.from_pretrained(self.model_name)
        text = f'<|im_start|>system\n<|im_end|>\n<|im_start|>user\nTranslate the following {self.sl} source text to {self.tl}:\n{self.sl}: {self.input_text} \n{self.tl}: <|im_end|>\n<|im_start|>assistant\n'
        inputs = tokenizer(text, return_tensors="pt")
        outputs = model.generate(**inputs, max_new_tokens=512)
        output = tokenizer.decode(outputs[0], skip_special_tokens=True)
        if f'{self.tl}:' in output:
            output = output.rsplit(f'{self.tl}:')[-1].strip().replace('assistant\n', '').strip()
        return output

    def teuken(self):
        """Translate with a Teuken instruct model; returns the decoded generation unfiltered."""
        model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        )
        model = model.to(self.device).eval()
        tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            use_fast=False,
            trust_remote_code=True,
        )
        translation_prompt = f"Translate the following text from {self.sl} into {self.tl}: {self.input_text}"
        messages = [{"role": "User", "content": translation_prompt}]
        prompt_ids = tokenizer.apply_chat_template(messages, chat_template="EN", tokenize=True, add_generation_prompt=False, return_tensors="pt")
        prediction = model.generate(
            prompt_ids.to(model.device),
            max_length=512,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
            num_return_sequences=1,
        )
        return tokenizer.decode(prediction[0].tolist())

    @staticmethod
    def _clean_unbabel_output(raw_text: str, source_text: str) -> str:
        """Strip chat markers and boilerplate from a Tower generation.

        Keeps the text after each known turn marker, removes a leading
        ``Answer:``, cuts everything from ``Translated text:`` on, and keeps at
        most as many lines as ``source_text`` has.
        """
        translated_text = raw_text
        for marker in ("<end_of_turn>", "<|im_end|>", "<|im_start|>assistant"):
            if marker in translated_text:
                translated_text = translated_text.split(marker)[1].strip()
        if translated_text.startswith('Answer:'):
            translated_text = translated_text.replace('Answer:', '', 1).strip()
        if "Translated text:" in translated_text:
            translated_text = translated_text.split("Translated text:")[0].strip()
        line_budget = source_text.count('\n') + 1
        return '\n'.join(translated_text.split('\n')[:line_budget])

    def unbabel(self):
        """Translate with an Unbabel Tower model through the text-generation pipeline."""
        pipe = pipeline("text-generation", model=self.model_name, torch_dtype=torch.bfloat16, device_map="auto")
        messages = [{"role": "user",
                     "content": f"Translate the following text from {self.sl} into {self.tl}.\n{self.sl}: {self.input_text}.\n{self.tl}:"}]
        prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
        tokenized_input = pipe.tokenizer(self.input_text, return_tensors="pt")
        num_input_tokens = len(tokenized_input["input_ids"][0])
        # Allow the answer to be up to 75% longer (in tokens) than the source text.
        max_new_tokens = round(num_input_tokens + 0.75 * num_input_tokens)
        outputs = pipe(prompt, max_new_tokens=max_new_tokens, do_sample=False)
        translated_text = outputs[0]["generated_text"]
        # Use this request's own text; the code used to read a module-level `input_text`
        # that only exists when the file runs as the Streamlit script.
        print(f"Input chars: {len(self.input_text)}", f"Input tokens: {num_input_tokens}", f"max_new_tokens: {max_new_tokens}",
              "Chars to tokens ratio:", round(len(self.input_text) / max(num_input_tokens, 1), 2), f"Raw translation: {translated_text}")
        return self._clean_unbabel_output(translated_text, self.input_text)

    @staticmethod
    def bergamot(model_name: str = 'deen', sl: str = 'de', tl: str = 'en', input_text: str = 'Hallo, mein Freund'):
        """Translate with a local Bergamot model read from ``./{model_name}/bergamot.config.yml``.

        Declared as a static method because the function takes no ``self``; it can
        be called on the class or on an instance with explicit arguments.

        Returns:
            ``(translation, message)`` on success, ``(error text, error details)``
            on failure. Previously a failure left both names unbound and raised
            ``UnboundLocalError``.
        """
        try:
            import bergamot

            # NOTE(review): VectorString is given a list of texts, as the commented-out
            # line in the original intended — confirm against the bergamot bindings.
            texts = [input_text] if isinstance(input_text, str) else input_text
            config = bergamot.ServiceConfig(numWorkers=4)
            service = bergamot.Service(config)
            model = service.modelFromConfigPath(f"./{model_name}/bergamot.config.yml")
            response_options = bergamot.ResponseOptions(alignment=False, qualityScores=False, HTML=False)
            rawresponse = service.translate(model, bergamot.VectorString(texts), response_options)
            translated_text: str = next(iter(rawresponse)).target.text
            message_text = f"Translated from {sl} to {tl} with Bergamot {model_name}."
        except Exception as error:
            translated_text = f"Error translating with Bergamot model: {model_name}!"
            message_text = str(error)
        return translated_text, message_text
def download_argos_model(from_code, to_code):
    """Download and install the Argos Translate package for one language pair.

    Refreshes the package index, picks the first available package whose
    source and target codes equal ``from_code`` and ``to_code``, then installs
    it from the downloaded file.

    Raises:
        StopIteration: if no available package matches the pair.
    """
    import argostranslate.package as argos_packages

    print('Downloading model', from_code, to_code)
    argos_packages.update_package_index()
    candidates = argos_packages.get_available_packages()
    matching = (
        pkg for pkg in candidates
        if pkg.from_code == from_code and pkg.to_code == to_code
    )
    chosen = next(matching)
    downloaded_path = chosen.download()
    argos_packages.install_from_path(downloaded_path)
def translate_text(model_name: str, s_language: str, t_language: str, input_text: str) -> tuple[str, str]:
    """
    Translates the input text from the source language to the target language using a specified model.

    Parameters:
        model_name (str): The selected translation model name
        s_language (str): The source language name of the input text (a key of ``all_langs``)
        t_language (str): The target language name (a key of ``all_langs``)
        input_text (str): The source text to be translated

    Returns:
        tuple:
            translated_text(str): The input text translated to the selected target language,
                or the error text when the translation failed
            message_text(str): A descriptive message summarizing the translation process.
                Example: "Translated from English to German with Helsinki-NLP."

    Raises:
        KeyError: if ``s_language`` or ``t_language`` is not a known language name.

    Example:
        >>> translate_text("Helsinki-NLP", "English", "German", "Hello world")
        ("Hallo Welt", "Translated from English to German with Helsinki-NLP.")
    """
    sl = all_langs[s_language][0]
    tl = all_langs[t_language][0]
    message_text = f'Translated from {s_language} to {t_language} with {model_name}'
    translated_text = None
    # Nothing to do for blank input; avoids loading a model for an empty request.
    if not input_text or not input_text.strip():
        return "", "Nothing to translate: the input text is empty."
    lowered = model_name.lower()
    try:
        # Some back-ends expect ISO 639-1 codes, others the full language names.
        by_code = Translators(model_name, sl, tl, input_text)
        by_name = Translators(model_name, s_language, t_language, input_text)
        if "-mul" in lowered or "mul-" in lowered or "-roa" in lowered:
            translated_text, message_text = by_code.HelsinkiNLP_mulroa()
        elif model_name == "Helsinki-NLP":
            translated_text, message_text = by_code.HelsinkiNLP()
        elif model_name == 'Argos':
            translated_text = by_code.argos()
        elif model_name == 'Google':
            translated_text = by_code.google()
        elif "m2m" in lowered:
            translated_text = by_code.mtom()
        elif "lego" in lowered:
            translated_text = by_code.LegoMT()
        elif model_name.startswith('t5'):
            translated_text = by_name.tfive()
        elif 'flan' in lowered:
            translated_text = by_name.flan()
        elif 'madlad' in lowered:
            translated_text = by_code.madlad()
        elif 'mt0' in lowered:
            translated_text = by_name.bigscience()
        elif 'bloomz' in lowered:
            translated_text = by_name.bloomz()
        elif 'nllb' in lowered:
            # NOTE(review): `languagecodes` is not imported anywhere in the visible file;
            # unless it is provided elsewhere this branch fails with NameError, which is
            # reported through the error path below.
            nnlbsl, nnlbtl = languagecodes.nllb_language_codes[s_language], languagecodes.nllb_language_codes[t_language]
            translated_text = Translators(model_name, nnlbsl, nnlbtl, input_text).nllb()
        elif model_name == "facebook/mbart-large-50-many-to-many-mmt":
            translated_text = by_name.mbart_many_to_many()
        elif model_name == "facebook/mbart-large-50-one-to-many-mmt":
            translated_text = by_name.mbart_one_to_many()
        elif model_name == "facebook/mbart-large-50-many-to-one-mmt":
            translated_text = by_name.mbart_many_to_one()
        elif 'teuken' in lowered:
            translated_text = by_name.teuken()
        elif model_name == "utter-project/EuroLLM-1.7B-Instruct":
            translated_text = by_name.eurollm_instruct()
        elif model_name == "utter-project/EuroLLM-1.7B":
            translated_text = by_name.eurollm()
        elif 'Unbabel' in model_name:
            translated_text = by_name.unbabel()
        elif model_name == "HuggingFaceTB/SmolLM3-3B":
            translated_text = by_name.smollm()
        elif model_name == "winninghealth/WiNGPT-Babel-2":
            translated_text = by_name.wingpt()
        elif "LLaMAX" in model_name:
            translated_text = by_name.LLaMAX()
        elif model_name == "Bergamot":
            # `bergamot` takes explicit arguments and no `self`, so it is called through
            # the class. NOTE(review): the model directory is presumed to be named
            # "<sl><tl>" (its default is 'deen' for de -> en) — confirm the layout.
            translated_text, message_text = Translators.bergamot(f"{sl}{tl}", sl, tl, input_text)
        else:
            translated_text = ""
            message_text = f"Model {model_name} is not supported."
    except Exception as error:
        # Report the failure as text and stop claiming that a translation took place.
        translated_text = str(error)
        message_text = f"Error translating from {s_language} to {t_language} with {model_name}: {error}"
    finally:
        print(input_text, translated_text, message_text)
    return translated_text, message_text
# App layout
st.header("Text Machine Translation", divider="gray", help="Text Machine Translation Streamlit App with Open Source Models")
input_text = st.text_area("Enter text to translate:", placeholder="Enter text to translate, maximum 512 characters!", max_chars=512)

# Seed the session state on the first run; later runs keep the stored selections.
_session_defaults = {
    "sselected_language": options[0],
    "tselected_language": options[1],
    "model_name": models[1],
}
for _key, _default in _session_defaults.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Model selection FIRST
model_name = st.selectbox("Select a model:", models,
                          index=models.index(st.session_state["model_name"]))

# Language selection: source | swap button | target
scol, swapcol, tcol = st.columns([3, 1, 3])
with scol:
    sselected_language = st.selectbox("Source language:", options,
                                      index=options.index(st.session_state["sselected_language"]))
with swapcol:
    if st.button("🔄 Swap"):
        st.session_state["model_name"] = model_name  # Preserve model
        # Exchange the stored languages, then rerun so both select boxes pick them up.
        st.session_state["sselected_language"], st.session_state["tselected_language"] = \
            st.session_state["tselected_language"], st.session_state["sselected_language"]
        st.rerun()
with tcol:
    tselected_language = st.selectbox("Target language:", options,
                                      index=options.index(st.session_state["tselected_language"]))

# Store selections
st.session_state["sselected_language"] = sselected_language
st.session_state["tselected_language"] = tselected_language
st.session_state["model_name"] = model_name

# Language codes of the current selection. They were previously read from the session
# state before it was refreshed, so they lagged one interaction behind the widgets.
sl = langs[sselected_language]
tl = langs[tselected_language]

# Written explicitly instead of relying on Streamlit rendering a bare expression.
st.write(f'Selected language combination: {sselected_language} - {tselected_language}. Selected model: {model_name}')

with st.container(border=None, width="stretch", height="content", horizontal=False, horizontal_alignment="center", vertical_alignment="center", gap="small"):
    submit_button = st.button("Translate")

# Handle the submit button click
if submit_button:
    if not input_text.strip():
        # Do not load a model just to translate nothing.
        st.warning("Please enter some text to translate.", icon=":material/warning:")
    else:
        with st.spinner("Translating...", show_time=True):
            translated_text, message = translate_text(model_name, sselected_language, tselected_language, input_text)
        print(f"Translated from {sselected_language} to {tselected_language} using {model_name}.", input_text, translated_text)
        # Back-ends may hand back None or an exception object; the widgets get plain text.
        shown_text = "" if translated_text is None else str(translated_text)
        st.text_area(":green[Translation:]", placeholder="Translation area", value=shown_text)
        st.info(str(message), icon=":material/info:")