import os
import subprocess

# Install flash attention and the agents extra of transformers before importing them
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
subprocess.run(
    "pip install git+https://github.com/huggingface/transformers.git#egg=transformers[agents]",
    shell=True,
)
# Install RAG dependencies
subprocess.run(
    "pip install langchain langchain-community sentence-transformers faiss-cpu",
    shell=True,
)

import json
import datasets
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.vectorstores import VectorStore
from transformers.agents import Tool, HfEngine, ReactJsonAgent
import copy
import spaces
import time
import torch
from threading import Thread
from typing import List, Dict, Tuple, Union
import urllib.request
from PIL import Image
import io
import gradio as gr
from transformers import AutoProcessor, TextIteratorStreamer
from transformers import Idefics2ForConditionalGeneration
DEVICE = torch.device("cuda")

MODELS = {
    "idefics2-8b-chatty": Idefics2ForConditionalGeneration.from_pretrained(
        "HuggingFaceM4/idefics2-8b-chatty",
        # "Ali-C137/idefics2-8b-chatty-yalla",
        torch_dtype=torch.bfloat16,
        _attn_implementation="flash_attention_2",
    ).to(DEVICE),
}
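
# Shared processor (tokenizer + image processor) used to build Idefics2 prompts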
PROCESSOR = AutoProcessor.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    # "Ali-C137/idefics2-8b-chatty-yalla",
)

# Load the custom dataset
knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")

# Process the documents
source_docs = [
    Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
    for doc in knowledge_base
]
docs_processed = RecursiveCharacterTextSplitter(chunk_size=500).split_documents(source_docs)[:1000]

# Create embeddings and vector store
embedding_model = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
vectordb = FAISS.from_documents(documents=docs_processed, embedding=embedding_model)
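

# Agent tool exposing semantic search over the FAISS index, with optional per-source filtering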
class RetrieverTool(Tool):
    name = "retriever"
    description = "Retrieves documents from the knowledge base that have the closest embeddings to the input query."
    inputs = {
        "query": {
            "type": "text",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        },
        "source": {
            "type": "text",
            "description": "",
        },
    }
    output_type = "text"

    def __init__(self, vectordb: VectorStore, all_sources: list, **kwargs):
        super().__init__(**kwargs)
        self.vectordb = vectordb
        self.inputs["source"]["description"] = (
            f"The source of the documents to search, as a str representation of a list. Possible values in the list are: {all_sources}. If this argument is not provided, all sources will be searched."
        )

    def forward(self, query: str, source: str = None) -> str:
        assert isinstance(query, str), "Your search query must be a string"
        if source:
            if isinstance(source, str) and "[" not in str(source):  # if the source is not a list representation
                source = [source]
            source = json.loads(str(source).replace("'", '"'))
        docs = self.vectordb.similarity_search(query, filter=({"source": source} if source else None), k=3)
        if len(docs) == 0:
            return "No documents found with this filtering. Try removing the source filter."
        return "Retrieved documents:\n\n" + "\n===Document===\n".join(
            [doc.page_content for doc in docs]
        )

# Initialize the LLM engine and the agent with the retriever tool
llm_engine = HfEngine("meta-llama/Meta-Llama-3-8B-Instruct")

all_sources = list(set([doc.metadata["source"] for doc in docs_processed]))
retriever_tool = RetrieverTool(vectordb, all_sources)
agent = ReactJsonAgent(tools=[retriever_tool], llm_engine=llm_engine)
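# Example (hypothetical query): answer = agent.run("How do I load a dataset from the Hub?")
# The ReAct agent is expected to call the retriever tool and return a text answer based on the retrieved chunks.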

# This section should be updated for the finetuned model
SYSTEM_PROMPT = [
    {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": "You are YALLA, a personalized AI chatbot assistant designed to enhance the user's experience in Morocco. Your mission is to provide accurate, real-time, and culturally rich information to make their visit enjoyable and stress-free. You can handle text and image inputs, offering recommendations on transport, event schedules, dining, accommodations, and cultural experiences. You can also perform real-time web searches and use various APIs to assist users effectively. Always be respectful, polite, and inclusive, and strive to offer truthful and helpful responses.",
            },
        ],
    },
    {
        "role": "assistant",
        "content": [
            {
                "type": "text",
                "text": "Hello, I'm YALLA, your personalized AI assistant for exploring Morocco. How can I assist you today?",
            },
        ],
    },
]
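
# Example prompts (text plus image files) surfaced in the Gradio ChatInterface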
examples_path = os.path.dirname(__file__)
EXAMPLES = [
    [
        {
            "text": "For 2024, the interest expense is twice what it was in 2014, and the long-term debt is 10% higher than its 2015 level. Can you calculate the combined total of the interest and long-term debt for 2024?",
            "files": [f"{examples_path}/example_images/mmmu_example_2.png"],
        }
    ],
    [
        {
            "text": "What's in the image?",
            "files": [f"{examples_path}/example_images/plant_bulb.webp"],
        }
    ],
    [
        {
            "text": "Describe the image",
            "files": [f"{examples_path}/example_images/baguettes_guarding_paris.png"],
        }
    ],
    [
        {
            "text": "Read what's written on the paper",
            "files": [f"{examples_path}/example_images/paper_with_text.png"],
        }
    ],
]

# BOT_AVATAR = "IDEFICS_logo.png"
BOT_AVATAR = "YALLA_logo.png"


# Chatbot utils
def turn_is_pure_media(turn):
    return turn[1] is None


def load_image_from_url(url):
    with urllib.request.urlopen(url) as response:
        image_data = response.read()
        image_stream = io.BytesIO(image_data)
        image = Image.open(image_stream)
        return image


def img_to_bytes(image_path):
    image = Image.open(image_path).convert(mode="RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    img_bytes = buffer.getvalue()
    image.close()
    return img_bytes


def format_user_prompt_with_im_history_and_system_conditioning(
    user_prompt, chat_history
) -> Tuple[List[Dict[str, Union[List, str]]], List[Image.Image]]:
    """
    Produces the message list that goes into the processor, along with the corresponding images.
    It handles the potential image(s), the history, and the system conditioning.
    """
    resulting_messages = copy.deepcopy(SYSTEM_PROMPT)
    resulting_images = []

    # Collect any images referenced by URL in the system prompt
    for resulting_message in resulting_messages:
        if resulting_message["role"] == "user":
            for content in resulting_message["content"]:
                if content["type"] == "image":
                    resulting_images.append(load_image_from_url(content["image"]))

    # Format history
    for turn in chat_history:
        if not resulting_messages or (
            resulting_messages and resulting_messages[-1]["role"] != "user"
        ):
            resulting_messages.append(
                {
                    "role": "user",
                    "content": [],
                }
            )
        if turn_is_pure_media(turn):
            media = turn[0][0]
            resulting_messages[-1]["content"].append({"type": "image"})
            resulting_images.append(Image.open(media))
        else:
            user_utterance, assistant_utterance = turn
            resulting_messages[-1]["content"].append(
                {"type": "text", "text": user_utterance.strip()}
            )
            resulting_messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": assistant_utterance.strip()}],
                }
            )

    # Format current input
    if not user_prompt["files"]:
        resulting_messages.append(
            {
                "role": "user",
                "content": [{"type": "text", "text": user_prompt["text"]}],
            }
        )
    else:
        # Choosing to put the image first (i.e. before the text), but this is an arbitrary choice.
        resulting_messages.append(
            {
                "role": "user",
                "content": [{"type": "image"}] * len(user_prompt["files"])
                + [{"type": "text", "text": user_prompt["text"]}],
            }
        )
        resulting_images.extend([Image.open(path) for path in user_prompt["files"]])

    return resulting_messages, resulting_images
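

# Helper to collect PIL images embedded directly in a message list (not called elsewhere in this file)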
def extract_images_from_msg_list(msg_list):
    all_images = []
    for msg in msg_list:
        for c_ in msg["content"]:
            if isinstance(c_, Image.Image):
                all_images.append(c_)
    return all_images
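

# Main Gradio callback: builds the multimodal prompt from the history, runs the RAG agent
# on the text query, and streams the Idefics2 generation token by token.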
def model_inference(
    user_prompt,
    chat_history,
    model_selector,
    decoding_strategy,
    temperature,
    max_new_tokens,
    repetition_penalty,
    top_p,
):
    if user_prompt["text"].strip() == "" and not user_prompt["files"]:
        raise gr.Error("Please input a query and optionally image(s).")
    if user_prompt["text"].strip() == "" and user_prompt["files"]:
        raise gr.Error("Please input a text query along with the image(s).")

    streamer = TextIteratorStreamer(
        PROCESSOR.tokenizer,
        skip_prompt=True,
        timeout=5.0,
    )

    # Common parameters to all decoding strategies
    generation_args = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
    }

    assert decoding_strategy in [
        "Greedy",
        "Top P Sampling",
    ]
    if decoding_strategy == "Greedy":
        generation_args["do_sample"] = False
    elif decoding_strategy == "Top P Sampling":
        generation_args["temperature"] = temperature
        generation_args["do_sample"] = True
        generation_args["top_p"] = top_p

    # Creating model inputs
    resulting_text, resulting_images = format_user_prompt_with_im_history_and_system_conditioning(
        user_prompt=user_prompt,
        chat_history=chat_history,
    )
    prompt = PROCESSOR.apply_chat_template(resulting_text, add_generation_prompt=True)
    inputs = PROCESSOR(
        text=prompt,
        images=resulting_images if resulting_images else None,
        return_tensors="pt",
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    generation_args.update(inputs)

    # Use the agent to perform RAG
    agent_output = agent.run(user_prompt["text"])
    print("Agent output:", agent_output)
    # Stream the generated text
    thread = Thread(
        target=MODELS[model_selector].generate,
        kwargs=generation_args,
    )
    thread.start()

    acc_text = ""
    for text_token in streamer:
        time.sleep(0.04)
        acc_text += text_token
        if acc_text.endswith("<end_of_utterance>"):
            acc_text = acc_text[: -len("<end_of_utterance>")]
        yield acc_text
    print("Success - generated the following text:", acc_text)
    print("-----")
FEATURES = datasets.Features(
    {
        "model_selector": datasets.Value("string"),
        "images": datasets.Sequence(datasets.Image(decode=True)),
        "conversation": datasets.Sequence({"User": datasets.Value("string"), "Assistant": datasets.Value("string")}),
        "decoding_strategy": datasets.Value("string"),
        "temperature": datasets.Value("float32"),
        "max_new_tokens": datasets.Value("int32"),
        "repetition_penalty": datasets.Value("float32"),
        "top_p": datasets.Value("float32"),
    }
)

# Hyper-parameters for generation
max_new_tokens = gr.Slider(
    minimum=8,
    maximum=1024,
    value=512,
    step=1,
    interactive=True,
    label="Maximum number of new tokens to generate",
)
repetition_penalty = gr.Slider(
    minimum=0.01,
    maximum=5.0,
    value=1.1,
    step=0.01,
    interactive=True,
    label="Repetition penalty",
    info="1.0 is equivalent to no penalty",
)
decoding_strategy = gr.Radio(
    [
        "Greedy",
        "Top P Sampling",
    ],
    value="Greedy",
    label="Decoding strategy",
    interactive=True,
    info="Choose between greedy decoding and nucleus (top-p) sampling.",
)
temperature = gr.Slider(
    minimum=0.0,
    maximum=5.0,
    value=0.4,
    step=0.1,
    visible=False,
    interactive=True,
    label="Sampling temperature",
    info="Higher values will produce more diverse outputs.",
)
top_p = gr.Slider(
    minimum=0.01,
    maximum=0.99,
    value=0.8,
    step=0.01,
    visible=False,
    interactive=True,
    label="Top P",
    info="Higher values sample more low-probability tokens.",
)

chatbot = gr.Chatbot(
    label="YALLA-Chatty",
    avatar_images=[None, BOT_AVATAR],
    height=450,
)

with gr.Blocks(fill_height=True) as demo:
    gr.Markdown("# 🇲🇦 YALLA ")
    with gr.Row(elem_id="model_selector_row"):
        model_selector = gr.Dropdown(
            choices=list(MODELS.keys()),
            value=list(MODELS.keys())[0],
            interactive=True,
            show_label=False,
            container=False,
            label="Model",
            visible=False,
        )

    # Show the temperature and top-p sliders only for sampling-based strategies
    decoding_strategy.change(
        fn=lambda selection: gr.Slider(
            visible=(
                selection
                in [
                    "contrastive_sampling",
                    "beam_sampling",
                    "Top P Sampling",
                    "sampling_top_k",
                ]
            )
        ),
        inputs=decoding_strategy,
        outputs=temperature,
    )
    decoding_strategy.change(
        fn=lambda selection: gr.Slider(visible=(selection in ["Top P Sampling"])),
        inputs=decoding_strategy,
        outputs=top_p,
    )

    gr.ChatInterface(
        fn=model_inference,
        chatbot=chatbot,
        examples=EXAMPLES,
        multimodal=True,
        cache_examples=False,
        additional_inputs=[
            model_selector,
            decoding_strategy,
            temperature,
            max_new_tokens,
            repetition_penalty,
            top_p,
        ],
    )
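
# Entry point: Spaces executes this file and serves the Gradio app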
demo.launch()