import os
import pickle
import tempfile
import traceback
from pathlib import Path

import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.applications import EfficientNetV2B0
from tensorflow.keras.applications.efficientnet import preprocess_input as efficientnet_preprocess
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.image import img_to_array
from PIL import Image
from huggingface_hub import hf_hub_download
import gradio as gr
# -----------------------------
# Custom attention layers
# -----------------------------
class ChannelAttention(layers.Layer):
    # CBAM-style channel attention over a (batch, tokens, channels) feature map.
    def __init__(self, ratio=8, **kwargs):
        super(ChannelAttention, self).__init__(**kwargs)
        self.ratio = ratio

    def build(self, input_shape):
        self.gap = layers.GlobalAveragePooling1D()
        self.gmp = layers.GlobalMaxPooling1D()
        # Shared MLP over both pooled descriptors; 1280 matches EfficientNetV2B0's feature width.
        self.shared_mlp = tf.keras.Sequential([
            layers.Dense(units=1280 // self.ratio, activation='relu'),
            layers.Dense(units=1280)
        ])
        self.sigmoid = layers.Activation('sigmoid')
        super(ChannelAttention, self).build(input_shape)

    def call(self, inputs):
        gap = self.gap(inputs)
        gmp = self.gmp(inputs)
        gap_mlp = self.shared_mlp(gap)
        gmp_mlp = self.shared_mlp(gmp)
        channel_attention = self.sigmoid(gap_mlp + gmp_mlp)
        # Broadcast the per-channel weights back across the token axis.
        return inputs * tf.expand_dims(channel_attention, axis=1)

    def get_config(self):
        config = super(ChannelAttention, self).get_config()
        config.update({'ratio': self.ratio})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)
class SpatialAttention(layers.Layer):
    # Spatial attention: a Conv1D sigmoid gate over the token axis.
    def __init__(self, **kwargs):
        super(SpatialAttention, self).__init__(**kwargs)

    def build(self, input_shape):
        self.conv = layers.Conv1D(1, kernel_size=3, padding='same', activation='sigmoid')
        super(SpatialAttention, self).build(input_shape)

    def call(self, inputs):
        spatial_attention = self.conv(inputs)
        return inputs * spatial_attention

    def get_config(self):
        return super(SpatialAttention, self).get_config()

    @classmethod
    def from_config(cls, config):
        return cls(**config)
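
# A quick shape sketch (assumption: the captioning model feeds a
# (batch, tokens, 1280) feature map through these blocks in sequence):
#   x = ChannelAttention(ratio=8)(x)   # (batch, 1280) gate, broadcast over tokens
#   x = SpatialAttention()(x)          # (batch, tokens, 1) gate, broadcast over channels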
# -----------------------------
# Load model + tokenizer
# -----------------------------
def load_caption_model(model_path):
    custom_objects = {
        'ChannelAttention': ChannelAttention,
        'SpatialAttention': SpatialAttention
    }
    model = load_model(model_path, custom_objects=custom_objects)
    print("✅ Model loaded successfully!")
    return model
def load_tokenizer_and_config(tokenizer_path, config_path):
    with open(tokenizer_path, 'rb') as f:
        tokenizer = pickle.load(f)
    with open(config_path, 'rb') as f:
        config = pickle.load(f)
    return tokenizer, config['max_length'], config['vocab_size']
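
# Assumed layout of the config pickle, based on the keys read above:
#   {'max_length': int, 'vocab_size': int}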
# -----------------------------
# Feature extractor - EfficientNetV2B0
# -----------------------------
def load_feature_extractor():
    # pooling='avg' yields a single 1280-d feature vector per image.
    base_model = EfficientNetV2B0(include_top=False, weights='imagenet', pooling='avg')
    return Model(inputs=base_model.input, outputs=base_model.output)
def extract_features_from_image(image_path, extractor):
    image = cv2.imread(image_path)  # note: cv2 loads images as BGR
    if image is None:
        print(f"❌ Could not read image: {image_path}")
        return None
    image = cv2.resize(image, (224, 224))
    image = img_to_array(image)
    image = np.expand_dims(image, axis=0)
    # EfficientNet (V1) preprocess call, presumably matching the training pipeline;
    # in recent TF releases this is a pass-through, as EfficientNet embeds its own rescaling.
    image = efficientnet_preprocess(image)
    feature = extractor.predict(image, verbose=0)
    return feature
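
# A minimal in-memory alternative (sketch, not wired into the app): extract
# features directly from a PIL image without a temp-file round trip. Assumes
# the same 224x224 resize and preprocessing as extract_features_from_image;
# RGB is flipped to BGR to reproduce cv2.imread's channel order exactly.
def extract_features_from_pil(pil_image, extractor):
    rgb = np.array(pil_image.convert("RGB"))
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])  # RGB -> BGR, matching cv2.imread
    bgr = cv2.resize(bgr, (224, 224))
    batch = np.expand_dims(bgr.astype("float32"), axis=0)
    batch = efficientnet_preprocess(batch)
    return extractor.predict(batch, verbose=0)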
# -----------------------------
# Generate caption
# -----------------------------
def generate_caption(model, tokenizer, image_features, max_length):
    # Greedy decoding: start from 'startseq' and append the argmax word each step.
    in_text = 'startseq'
    for _ in range(max_length):
        sequence = tokenizer.texts_to_sequences([in_text])[0]
        sequence = pad_sequences([sequence], maxlen=max_length)
        yhat = model.predict([image_features, sequence], verbose=0)
        yhat = np.argmax(yhat)
        word = tokenizer.index_word.get(yhat)
        if word is None or word == 'endseq':
            break
        in_text += ' ' + word
    return in_text.replace('startseq ', '')
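
# Usage sketch (hypothetical standalone run, assuming startup below succeeded
# and "example.jpg" is a hypothetical local image):
#   features = extract_features_from_image("example.jpg", extractor)
#   print(generate_caption(model, tokenizer, features, max_length))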
# -----------------------------
# App startup
# -----------------------------
MODEL_REPO = "slyviee/img_cap"
# Initialize global resources when the app starts
model_path = hf_hub_download(repo_id=MODEL_REPO, filename="best_model.keras")
tokenizer_path = hf_hub_download(repo_id=MODEL_REPO, filename="tokenizer.pkl")
config_path = hf_hub_download(repo_id=MODEL_REPO, filename="model_config.pkl")
model = None
tokenizer = None
max_length = None
vocab_size = None
extractor = None
ready = False
startup_error = ""
def _startup():
    global model, tokenizer, max_length, vocab_size, extractor, ready, startup_error
    try:
        # Check that all required files exist
        missing = [p for p in [model_path, tokenizer_path, config_path] if not Path(p).exists()]
        if missing:
            startup_error = "Missing files: " + ", ".join(missing)
            ready = False
            return
        print("🔄 Loading model...")
        model = load_caption_model(model_path)
        print("✅ Model loaded.")
        print("🔄 Loading tokenizer and config...")
        tokenizer, max_length, vocab_size = load_tokenizer_and_config(tokenizer_path, config_path)
        print("✅ Tokenizer and config loaded.")
        print("🔄 Loading feature extractor...")
        extractor = load_feature_extractor()
        print("✅ Feature extractor loaded.")
        ready = True
    except Exception as e:
        startup_error = f"Startup error: {e}\n{traceback.format_exc()}"
        ready = False
def predict(pil_image: Image.Image):
    if not ready:
        return f"The system is not ready. {startup_error or 'Missing model/tokenizer/config.'}"
    try:
        # Save to a temp file so we can reuse extract_features_from_image (which reads via cv2)
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            pil_image.convert("RGB").save(tmp.name, format="JPEG")
            tmp_path = tmp.name
        features = extract_features_from_image(tmp_path, extractor)
        os.unlink(tmp_path)
        if features is None:
            return "Could not read the input image."
        caption = generate_caption(model, tokenizer, features, max_length)
        return caption
    except Exception as e:
        return f"Error during prediction: {e}\n{traceback.format_exc()}"
DESCRIPTION = (
    "Upload an image and get a caption generated by the model."
)
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Input image"),
    outputs=gr.Textbox(label="Caption"),
    title="Image Captioning — Gradio",
    description=DESCRIPTION,
)
if __name__ == '__main__':
    _startup()
    demo.launch()