File size: 4,157 Bytes
8dabd6b
 
 
 
 
 
 
 
 
 
 
dd680c4
8dabd6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc46203
8dabd6b
bc46203
 
8dabd6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr
import tensorflow as tf
from translator import Translator
from utils import tokenizer_utils
from utils.preprocessing import input_processing, output_processing
from models.transformer import Transformer
from models.encoder import Encoder
from models.decoder import Decoder
from models.layers import EncoderLayer, DecoderLayer, MultiHeadAttention, point_wise_feed_forward_network
from models.utils import masked_loss, masked_accuracy

def load_model_and_tokenizers(model_path="ckpts/en_vi_translation.keras"):
    """
    Load the pre-trained model and tokenizers.

    Args:
        model_path (str): Path to the pre-trained model file.

    Returns:
        tuple: ``(model, en_tokenizer, vi_tokenizer)`` where ``model`` is the
        object returned by ``tf.keras.models.load_model`` and the tokenizers
        are the pair returned by ``tokenizer_utils.load_tokenizers()``.

    Raises:
        Exception: If loading the model or the tokenizers fails. The
            original error is attached as ``__cause__``.
    """
    # Project-defined classes and functions handed to load_model so the
    # saved file can be deserialized.
    custom_objects = {
        "Transformer": Transformer,
        "Encoder": Encoder,
        "Decoder": Decoder,
        "EncoderLayer": EncoderLayer,
        "DecoderLayer": DecoderLayer,
        "MultiHeadAttention": MultiHeadAttention,
        "point_wise_feed_forward_network": point_wise_feed_forward_network,
        "masked_loss": masked_loss,
        "masked_accuracy": masked_accuracy,
    }

    # Keep each try body down to the single call that can fail, and chain
    # the original exception so the root cause stays in the traceback.
    try:
        model = tf.keras.models.load_model(model_path, custom_objects=custom_objects)
    except Exception as e:
        raise Exception(f"Failed to load model: {e}") from e
    print("Model loaded successfully.")

    try:
        en_tokenizer, vi_tokenizer = tokenizer_utils.load_tokenizers()
    except Exception as e:
        raise Exception(f"Failed to load tokenizers: {e}") from e
    print("Tokenizers loaded successfully.")

    return model, en_tokenizer, vi_tokenizer

def translate_sentence(sentence, model, en_tokenizer, vi_tokenizer):
    """
    Translate a single English sentence to Vietnamese.

    Args:
        sentence (str): English sentence to translate. ``None``, an empty
            string, or a whitespace-only string is rejected.
        model: Pre-trained translation model.
        en_tokenizer: English tokenizer.
        vi_tokenizer: Vietnamese tokenizer.

    Returns:
        str: The translated sentence after ``output_processing``, or the
        message ``"Please provide a valid sentence."`` when the input is
        missing or blank.
    """
    # Guard None explicitly: calling .strip() on it would raise
    # AttributeError instead of returning the validation message.
    if sentence is None or not sentence.strip():
        return "Please provide a valid sentence."

    # Build the translator from the supplied tokenizers and model
    translator = Translator(en_tokenizer, vi_tokenizer, model)

    # Pre-process, translate, then post-process
    processed_sentence = input_processing(sentence)
    translated_text = translator(processed_sentence)
    return output_processing(translated_text)

# Load model and tokenizers once at startup so every request reuses them.
# A failure here aborts the import; the original error is chained as the cause.
try:
    model, en_tokenizer, vi_tokenizer = load_model_and_tokenizers()
except Exception as e:
    raise Exception(f"Initialization failed: {e}") from e

# Define Gradio interface
def gradio_translate(sentence):
    """
    Single-argument adapter used as the Gradio callback.

    Forwards the text to ``translate_sentence`` together with the
    module-level ``model``, ``en_tokenizer`` and ``vi_tokenizer``.

    Args:
        sentence (str): Input English sentence.

    Returns:
        str: Whatever ``translate_sentence`` returns for this input.
    """
    shared_resources = (model, en_tokenizer, vi_tokenizer)
    translation = translate_sentence(sentence, *shared_resources)
    return translation

# Create Gradio interface: one text input, one text output, backed by
# gradio_translate.
iface = gr.Interface(
    fn=gradio_translate,
    inputs=gr.Textbox(
        label="Enter English Sentence",
        placeholder="Type an English sentence to translate to Vietnamese...",
        lines=2
    ),
    outputs=gr.Textbox(
        label="Translated Vietnamese Sentence"
    ),
    title="English to Vietnamese Translation Transformer using TensorFlow",
    # Adjacent string literals are concatenated with no separator, so each
    # fragment carries its own trailing punctuation and space.
    description=(
        "English to Vietnamese Translation Transformer using TensorFlow from Scratch. "
        "Enter an English sentence to translate it to Vietnamese. "
        "Example: 'Hello, world!'"
    ),
    examples=[
        [
            "For at least six centuries, residents along a lake in the mountains of central Japan "
            "have marked the depth of winter by celebrating the return of a natural phenomenon "
            "once revered as the trail of a wandering god."
        ],
        ["Hello, world!"],
        ["The sun is shining."]
    ]
)

# Launch the app only when this file is run as a script, not when imported
if __name__ == "__main__":
    iface.launch()