Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| #from streamlit_datalist import stDatalist | |
| from utils import convert_to_base64, convert_to_html | |
| import requests | |
# HTTP inference endpoint that ask_llm() posts prompt/image payloads to.
IP = "127.0.0.1"
PORT = 8080
url = f"http://{IP}:{PORT}/predictions/model"
headers = {"Content-Type": "application/json"}

# Page chrome. set_page_config is issued before any other st.* call
# (older Streamlit releases require it to be the first one).
st.set_page_config(page_title="AWS Inferentia2 Demo", layout="wide")
st.title("Multimodal Model on AWS Inf2")
st.subheader("LLaVA-1.6-Mistral-7B")
def upload_image():
    """Render the image picker and return the chosen image in base64 form.

    The user either picks one of the bundled preset images or uploads a
    file of their own; an uploaded file takes precedence over the preset
    selection. Only a single image is supported.

    Returns:
        Whatever ``convert_to_base64`` produces for the selected image
        (the value later sent to the model as ``"image"``).

    If no image has been chosen yet, or more than one file was uploaded,
    the Streamlit script run is halted with ``st.stop()`` and this function
    does not return.
    """
    # Display label -> path of the bundled preset image (insertion order is
    # the order shown in the select box).
    presets = {
        "view(https://llava-vl.github.io/static/images/view.jpg)": "./images/view.jpg",
        "cat": "./images/cat.jpg",
        "paris 2024": "./images/olympic.jpg",
        "statue of liberty": "./images/usa.jpg",
        "box(from my camera)": "./images/box.jpg",
    }
    placeholder = "–Select–"
    user_option = st.selectbox("Select a preset image", [placeholder, *presets])

    st.text("OR")
    # `or []` guards against a None return so len() below is always safe.
    uploads = st.file_uploader(
        "Upload an image to chat about",
        type=["png", "jpg", "jpeg"],
        accept_multiple_files=True,
    ) or []

    # Explicit check instead of `assert`: assertions are stripped under -O.
    if len(uploads) > 1:
        st.error("Please upload at most 1 image")
        st.stop()

    # An uploaded file wins over the preset selection.
    if uploads:
        source = uploads[0]
    elif user_option != placeholder:
        source = presets[user_option]
    else:
        # Nothing selected yet: end this run and wait for user input.
        st.stop()

    image_b64 = convert_to_base64(source)
    st.markdown("**Image 1**")
    st.markdown(convert_to_html(image_b64), unsafe_allow_html=True)
    st.markdown("---")
    return image_b64
def ask_llm(prompt, byte_image, *, timeout=300):
    """Send a question plus image to the inference endpoint and return its reply.

    Args:
        prompt: The user's question about the image.
        byte_image: The encoded image returned by ``upload_image``.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely if the server stalls.

    Returns:
        The raw response body as text.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status, so
            an error payload is never displayed as if it were a model answer.
        requests.RequestException: on connection failures or timeouts.
    """
    payload = {
        "prompt": prompt,
        "image": byte_image,
        # Sampling settings forwarded to the model server unchanged.
        "parameters": {
            "top_k": 100,
            "top_p": 0.1,
            "temperature": 0.2,
        },
    }
    response = requests.post(url, json=payload, headers=headers, timeout=timeout)
    response.raise_for_status()
    return response.text
def app():
    """Lay out the page and run one question/answer turn against the model.

    The right-hand column holds the image picker and the left-hand column
    holds the chat. The script run stops early (``st.stop()``) until both
    an image and a question are available.
    """
    st.markdown("---")
    chat_col, image_col = st.columns(2)
    with image_col:
        image_b64 = upload_image()
    with chat_col:
        question = st.chat_input("Ask a question about this image")
    if not question:
        st.stop()
    with chat_col:
        with st.chat_message("question"):
            # Plain markdown only: enabling raw HTML here would let arbitrary
            # user input inject markup into the page.
            st.markdown(question)
        with st.spinner("Thinking..."):
            try:
                res = ask_llm(question, image_b64)
            except requests.RequestException as err:
                # Backend unreachable, timed out, or returned an error status.
                st.error(f"Model request failed: {err}")
                st.stop()
        with st.chat_message("response"):
            st.write(res)


if __name__ == "__main__":
    # Presumably launched via `streamlit run`, which executes the script as __main__.
    app()