Spaces:
Sleeping
Sleeping
| import os | |
| import cv2 | |
| import numpy as np | |
| import pickle | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import tensorflow as tf | |
| from tensorflow.keras import layers | |
| from tensorflow.keras.models import load_model, Model | |
| from tensorflow.keras.applications import EfficientNetV2B0 | |
| from tensorflow.keras.applications.efficientnet import preprocess_input as efficientnet_preprocess | |
| from tensorflow.keras.preprocessing.sequence import pad_sequences | |
| from tensorflow.keras.preprocessing.image import img_to_array | |
| from tqdm import tqdm | |
| import random | |
| from tensorflow.keras.preprocessing.sequence import pad_sequences | |
| import tempfile | |
| import traceback | |
| from pathlib import Path | |
| from huggingface_hub import hf_hub_download | |
| import gradio as gr | |
| from PIL import Image | |
| import pickle | |
| # ----------------------------- | |
| # Custom attention layers | |
| # ----------------------------- | |
class ChannelAttention(layers.Layer):
    """Channel-attention layer that reweights the last axis of its input.

    Global average pooling and global max pooling (1D) each summarize the
    input. Both summaries go through one shared two-layer MLP, and the
    sigmoid of their sum scales the input channel-wise.

    Args:
        ratio: Reduction factor for the MLP's hidden layer
            (hidden units = ``channels // ratio``).
        channels: Output width of the MLP, i.e. how many channels get a
            weight. Defaults to 1280, the value that used to be hard-coded.
            Saved models whose config lacks this key therefore rebuild with
            identical weight shapes.
        **kwargs: Forwarded to ``layers.Layer``.

    NOTE(review): presumably the input is (batch, steps, channels) with
    ``channels`` on the last axis, since the 1D global pooling layers expect
    a rank-3 tensor. TODO confirm against the saved model graph.
    """

    def __init__(self, ratio=8, channels=1280, **kwargs):
        super().__init__(**kwargs)
        self.ratio = ratio
        self.channels = channels

    def build(self, input_shape):
        # Sub-layers are created lazily here, as in the original design.
        self.gap = layers.GlobalAveragePooling1D()
        self.gmp = layers.GlobalMaxPooling1D()
        # A single MLP is shared by both pooled descriptors.
        self.shared_mlp = tf.keras.Sequential([
            layers.Dense(units=self.channels // self.ratio, activation='relu'),
            layers.Dense(units=self.channels),
        ])
        self.sigmoid = layers.Activation('sigmoid')
        super().build(input_shape)

    def call(self, inputs):
        gap = self.gap(inputs)
        gmp = self.gmp(inputs)
        gap_mlp = self.shared_mlp(gap)
        gmp_mlp = self.shared_mlp(gmp)
        channel_attention = self.sigmoid(gap_mlp + gmp_mlp)
        # Re-insert the pooled axis so the weights broadcast over axis 1.
        return inputs * tf.expand_dims(channel_attention, axis=1)

    def get_config(self):
        config = super().get_config()
        config.update({'ratio': self.ratio, 'channels': self.channels})
        return config

    @classmethod
    def from_config(cls, config):
        # Must be a classmethod. Keras invokes it on the class itself
        # (cls.from_config(config)). Without the decorator, `config` would
        # bind to `cls` and the call would fail with a missing argument.
        return cls(**config)
class SpatialAttention(layers.Layer):
    """Spatial-attention layer that gates every position with a learned mask.

    A single-filter ``Conv1D`` (kernel size 3, 'same' padding, sigmoid
    activation) produces one weight per position, which then multiplies
    the input.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        self.conv = layers.Conv1D(1, kernel_size=3, padding='same', activation='sigmoid')
        super().build(input_shape)

    def call(self, inputs):
        # The mask has one channel, so it broadcasts across the input's last axis.
        spatial_attention = self.conv(inputs)
        return inputs * spatial_attention

    def get_config(self):
        # No constructor arguments beyond those of the base Layer.
        return super().get_config()

    @classmethod
    def from_config(cls, config):
        # Must be a classmethod, because Keras calls cls.from_config(config).
        return cls(**config)
| # ----------------------------- | |
| # Load model + tokenizer | |
| # ----------------------------- | |
def load_caption_model(model_path):
    """Load the saved caption model from ``model_path``.

    The two custom attention layers are registered so Keras can rebuild
    them during deserialization. A success message is printed and the
    loaded model is returned.
    """
    loaded = load_model(
        model_path,
        custom_objects={
            'ChannelAttention': ChannelAttention,
            'SpatialAttention': SpatialAttention,
        },
    )
    print("✅ Đã load model thành công!")
    return loaded
def load_tokenizer_and_config(tokenizer_path, config_path):
    """Unpickle the tokenizer and the model config.

    Args:
        tokenizer_path: Path to the pickled tokenizer.
        config_path: Path to a pickled mapping containing at least the keys
            'max_length' and 'vocab_size'.

    Returns:
        Tuple ``(tokenizer, max_length, vocab_size)``.

    Raises:
        KeyError: if the config lacks 'max_length' or 'vocab_size'.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. In this app the paths come from hf_hub_download, so only point
    # this at repositories you trust.
    unpickled = []
    for path in (tokenizer_path, config_path):
        with open(path, 'rb') as handle:
            unpickled.append(pickle.load(handle))
    tokenizer, config = unpickled
    return tokenizer, config['max_length'], config['vocab_size']
| # ----------------------------- | |
| # Feature extractor - EfficientNetV2B0 | |
| # ----------------------------- | |
def load_feature_extractor():
    """Build an ImageNet-pretrained EfficientNetV2B0 that outputs pooled features.

    The classification head is dropped (``include_top=False``) and the final
    feature map is global-average-pooled (``pooling='avg'``).
    """
    backbone = EfficientNetV2B0(
        include_top=False,
        weights='imagenet',
        pooling='avg',
    )
    # Re-expose the backbone's own input/output tensors as a plain Model.
    return Model(inputs=backbone.input, outputs=backbone.output)
def extract_features_from_image(image_path, extractor):
    """Read an image file and run it through ``extractor``.

    The image is loaded with OpenCV, resized to 224x224, converted to an
    array, given a batch axis, and passed through ``efficientnet_preprocess``
    before prediction.

    Returns:
        The extractor's prediction for a one-image batch. Returns None, after
        printing a message, when OpenCV cannot read the file.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        print(f"❌ Không đọc được ảnh: {image_path}")
        return None
    # NOTE(review): cv2.imread yields BGR channel order, while Keras
    # ImageNet backbones presumably expect RGB. This is left as-is on
    # purpose: if the caption model was trained on features extracted the
    # same way, converting here would introduce train/serve skew. Confirm
    # against the training pipeline.
    resized = cv2.resize(bgr, (224, 224))
    batch = np.expand_dims(img_to_array(resized), axis=0)
    return extractor.predict(efficientnet_preprocess(batch), verbose=0)
| # ----------------------------- | |
| # Generate caption | |
| # ----------------------------- | |
def generate_caption(model, tokenizer, image_features, max_length):
    """Greedily decode a caption for a single image.

    Decoding starts from the ``startseq`` token. At each step, the image
    features and the padded token sequence so far go to ``model``, and the
    highest-scoring word is appended. Decoding stops at ``endseq``, at an id
    missing from ``tokenizer.index_word``, or after ``max_length`` steps.

    Args:
        model: Caption model, called as ``model.predict([image_features, seq])``.
        tokenizer: Fitted tokenizer providing ``texts_to_sequences`` and
            ``index_word``.
        image_features: Extractor output for one image.
        max_length: Maximum number of decoding steps; also the length the
            input sequence is padded to.

    Returns:
        The caption without the ``startseq``/``endseq`` markers. If the model
        stops immediately, this is an empty string (the previous
        implementation leaked the literal ``"startseq"`` in that case).
    """
    start_token = 'startseq'
    end_token = 'endseq'
    words = []
    for _ in range(max_length):
        in_text = ' '.join([start_token] + words)
        sequence = tokenizer.texts_to_sequences([in_text])[0]
        sequence = pad_sequences([sequence], maxlen=max_length)
        yhat = model.predict([image_features, sequence], verbose=0)
        # argmax over the (flattened) prediction; cast to a plain int key.
        next_id = int(np.argmax(yhat))
        word = tokenizer.index_word.get(next_id)
        if word is None or word == end_token:
            break
        words.append(word)
    # Joining only the generated words avoids the 'startseq ' prefix-strip
    # bug when nothing was generated.
    return ' '.join(words)
# -----------------------------
# App setup: download artifacts and declare shared state
# -----------------------------
MODEL_REPO = "slyviee/img_cap"

# Fetch the three required artifacts from the Hub when the module is
# imported (i.e. at app start), in this order.
model_path, tokenizer_path, config_path = (
    hf_hub_download(repo_id=MODEL_REPO, filename=artifact)
    for artifact in ("best_model.keras", "tokenizer.pkl", "model_config.pkl")
)

# Filled in by _startup(); None until it runs successfully.
model = tokenizer = max_length = vocab_size = extractor = None
# Readiness flag plus the error text that predict() reports if startup failed.
ready = False
startup_error = ""
def _startup():
    """Load the model, tokenizer/config and feature extractor into module globals.

    On success, ``ready`` is set to True. Loading errors are not propagated.
    A missing artifact file or any exception during loading sets ``ready``
    to False and stores a message in ``startup_error`` for predict() to show.
    """
    global model, tokenizer, max_length, vocab_size, extractor, ready, startup_error
    try:
        # Do not load anything if any of the downloaded files is absent.
        absent = [
            candidate
            for candidate in (model_path, tokenizer_path, config_path)
            if not Path(candidate).exists()
        ]
        if absent:
            startup_error = "Thiếu tệp: " + ", ".join(absent)
            ready = False
            return

        print("🔄 Đang tải model...")
        model = load_caption_model(model_path)
        print("✅ Model đã được tải.")

        print("🔄 Đang tải tokenizer và config...")
        tokenizer, max_length, vocab_size = load_tokenizer_and_config(
            tokenizer_path, config_path
        )
        print("✅ Tokenizer và config đã được tải.")

        print("🔄 Đang tải feature extractor...")
        extractor = load_feature_extractor()
        print("✅ Feature extractor đã được tải.")
    except Exception as exc:
        startup_error = f"Khởi tạo lỗi: {exc}\n{traceback.format_exc()}"
        ready = False
    else:
        # Reached only when every load step completed (not after the early return).
        ready = True
def predict(pil_image: Image.Image):
    """Gradio handler that returns a caption for ``pil_image``, or a message.

    The image is written to a temporary JPEG so it can go through
    ``extract_features_from_image``, which reads files with OpenCV. The temp
    file is always removed afterwards, even when feature extraction or
    caption generation raises.

    Args:
        pil_image: Image from ``gr.Image(type="pil")``. It may be None
            (presumably when the user submits without uploading an image).

    Returns:
        The generated caption, or a status/error message string. Exceptions
        are reported in the returned text rather than raised.
    """
    if not ready:
        return f"Hệ thống chưa sẵn sàng. {startup_error or 'Thiếu model/tokenizer/config.'}"
    if pil_image is None:
        return "Vui lòng tải ảnh lên."
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            tmp_path = tmp.name
            # Write through the already-open handle instead of reopening by name.
            pil_image.convert("RGB").save(tmp, format="JPEG")
        features = extract_features_from_image(tmp_path, extractor)
        if features is None:
            return "Không đọc được ảnh đầu vào."
        return generate_caption(model, tokenizer, features, max_length)
    except Exception as e:
        return f"Lỗi trong quá trình dự đoán: {e}\n{traceback.format_exc()}"
    finally:
        # Best-effort cleanup. Previously the file leaked whenever an
        # exception skipped the unlink. A cleanup failure must not mask the
        # result above.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
# Short blurb shown under the app title.
DESCRIPTION = "Upload ảnh và nhận caption sinh ra bởi mô hình. "

# Gradio UI: one PIL image in, one caption textbox out, handled by predict().
image_input = gr.Image(type="pil", label="Ảnh vào")
caption_output = gr.Textbox(label="Caption")
demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=caption_output,
    title="Image Captioning — Gradio",
    description=DESCRIPTION,
)

if __name__ == "__main__":
    # Load resources first so predict() can serve requests, then start the server.
    _startup()
    demo.launch()