File size: 5,387 Bytes
c8d3212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc9e5e3
 
 
 
 
 
 
 
 
 
 
 
c8d3212
 
dc9e5e3
 
 
 
 
c8d3212
 
 
 
 
 
e2b23c0
c8d3212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#MisterAI/Docker_Ollama
#app.py_02
#https://huggingface.co/spaces/MisterAI/Docker_Ollama/

import logging
import requests
from pydantic import BaseModel
from langchain_community.llms import Ollama
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
import gradio as gr
import threading
import subprocess
from bs4 import BeautifulSoup

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Cache pour stocker les modèles déjà chargés
loaded_models = {}

# Variable pour suivre l'état du bouton "Stop"
stop_flag = False

def get_model_list():
    url = "https://ollama.com/search"
    response = requests.get(url)

    # Vérifier si la requête a réussi
    if response.status_code == 200:
        # Utiliser BeautifulSoup pour analyser le HTML
        soup = BeautifulSoup(response.text, 'html.parser')
        model_list = []

        # Trouver tous les éléments de modèle
        model_elements = soup.find_all('li', {'x-test-model': True})

        for model_element in model_elements:
            model_name = model_element.find('span', {'x-test-search-response-title': True}).text.strip()
            size_elements = model_element.find_all('span', {'x-test-size': True})

#            # Filtrer les modèles par taille
#            for size_element in size_elements:
#                size = size_element.text.strip()
#                if size.endswith('m'):
#                    # Tous les modèles en millions sont acceptés
#                    model_list.append(f"{model_name}:{size}")
#                elif size.endswith('b'):
#                    # Convertir les modèles en milliards en milliards
#                    size_value = float(size[:-1])
#                    if size_value <= 10:  # Filtrer les modèles <= 10 milliards de paramètres
#                        model_list.append(f"{model_name}:{size}")

            # Filtrer les modèles par taille
            for size_element in size_elements:
                size = size_element.text.strip().lower()  # Convertir en minuscules
                if 'x' in size:
                    # Exclure les modèles avec des tailles de type nXm ou nXb
                    continue
                elif size.endswith('m'):
                    # Tous les modèles en millions sont acceptés
                    model_list.append(f"{model_name}:{size}")
                elif size.endswith('b'):
                    # Convertir les modèles en milliards en milliards
                    size_value = float(size[:-1])
                    if size_value <= 10:  # Filtrer les modèles <= 10 milliards de paramètres
                        model_list.append(f"{model_name}:{size}")

        return model_list
    else:
        logger.error(f"Erreur lors de la récupération de la liste des modèles : {response.status_code} - {response.text}")
        return []

def get_llm(model_name):
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
    return Ollama(model=model_name, callback_manager=callback_manager)

class InputData(BaseModel):
    model_name: str
    input: str
    max_tokens: int = 256
    temperature: float = 0.7

def pull_model(model_name):
    try:
        # Exécuter la commande pour tirer le modèle
        subprocess.run(["ollama", "pull", model_name], check=True)
        logger.info(f"Model {model_name} pulled successfully.")
    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to pull model {model_name}: {e}")
        raise

def check_and_load_model(model_name):
    # Vérifier si le modèle est déjà chargé
    if model_name in loaded_models:
        logger.info(f"Model {model_name} is already loaded.")
        return loaded_models[model_name]
    else:
        logger.info(f"Loading model {model_name}...")
        # Tirer le modèle si nécessaire
        pull_model(model_name)
        llm = get_llm(model_name)
        loaded_models[model_name] = llm
        return llm

# Interface Gradio
def gradio_interface(model_name, input, max_tokens, temperature, stop_button=None):
    global stop_flag
    stop_flag = False
    response = None  # Initialisez la variable response ici

    def worker():
        nonlocal response  # Utilisez nonlocal pour accéder à la variable response définie dans la fonction parente
        llm = check_and_load_model(model_name)
        response = llm(input, max_tokens=max_tokens, temperature=temperature)

    thread = threading.Thread(target=worker)
    thread.start()
    thread.join()

    if stop_flag:
        return "Processing stopped by the user."
    else:
        return response  # Maintenant, response est accessible ici

model_list = get_model_list()

demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Dropdown(model_list, label="Select Model", value="mistral:7b"),
        gr.Textbox(label="Input"),
        gr.Slider(minimum=1, maximum=2048, step=1, label="Max Tokens", value=256),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.7),
        gr.Button(value="Stop", variant="stop")
    ],
    outputs=[
        gr.Textbox(label="Output")
    ],
    title="Ollama Demo"
)

def stop_processing():
    global stop_flag
    stop_flag = True

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)