#MisterAI/Docker_Ollama
#app.py_01
#https://huggingface.co/spaces/MisterAI/Docker_Ollama/
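# Gradio front end for a local Ollama server: the user picks a model from the
# ollama.com catalogue, the model is pulled on first use, and the prompt is
# answered through the LangChain Ollama wrapper.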


import logging
import requests
from pydantic import BaseModel
from langchain_community.llms import Ollama
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
import gradio as gr
import threading
import subprocess


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)



# Cache of models that have already been loaded
loaded_models = {}


# Flag tracking whether the "Stop" button has been pressed
stop_flag = False


def get_model_list():
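    """Scrape the Ollama search page and return the list of model names.

    There is no JSON endpoint used here, so this parsing depends on the
    current HTML markup of https://ollama.com/search.
    """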
    url = "https://ollama.com/search"
    response = requests.get(url)

    # Check whether the request succeeded
    if response.status_code == 200:
        # Extract the model names from the HTML page
        model_list = [model.strip() for model in response.text.split('<span x-test-search-response-title>')[1:]]
        model_list = [model.split('</span>')[0] for model in model_list]
        return model_list
    else:
        logger.error(f"Erreur lors de la récupération de la liste des modèles : {response.status_code} - {response.text}")
        return []

def get_llm(model_name):
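    """Build a LangChain Ollama LLM wrapper that streams tokens to stdout."""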
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
    return Ollama(model=model_name, callback_manager=callback_manager)

class InputData(BaseModel):
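    """Generation request parameters (defined but not used elsewhere in this
    script; presumably intended as an API input schema)."""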
    model_name: str
    input: str
    max_tokens: int = 256
    temperature: float = 0.7



def pull_model(model_name):
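    """Download a model via the Ollama CLI (requires `ollama` on the PATH)."""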
    try:
        # Run the CLI command to pull the model
        subprocess.run(["ollama", "pull", model_name], check=True)
        logger.info(f"Model {model_name} pulled successfully.")
    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to pull model {model_name}: {e}")
        raise

def check_and_load_model(model_name):
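    """Return a cached LLM for model_name, pulling and wrapping it on first use."""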
    # Check whether the model is already loaded
    if model_name in loaded_models:
        logger.info(f"Model {model_name} is already loaded.")
        return loaded_models[model_name]
    else:
        logger.info(f"Loading model {model_name}...")
        # Pull the model if necessary
        pull_model(model_name)
        llm = get_llm(model_name)
        loaded_models[model_name] = llm
        return llm

        
        
        
# Interface Gradio
def gradio_interface(model_name, input, max_tokens, temperature, stop_button=None):
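    """Run a single generation request in a worker thread and return its output."""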
    global stop_flag
    stop_flag = False
    response = None  # Initialize response here so the worker thread can assign to it

    def worker():
        nonlocal response  # nonlocal gives the worker access to response defined in the enclosing function
        llm = check_and_load_model(model_name)
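        # Note: max_tokens/temperature are forwarded as keyword arguments; the
        # Ollama wrapper's native option for output length is num_predict, so
        # this may or may not cap the response depending on the wrapper version.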
        response = llm(input, max_tokens=max_tokens, temperature=temperature)

    thread = threading.Thread(target=worker)
    thread.start()
    thread.join()
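    # Note: thread.join() blocks until generation finishes, so stop_flag can
    # only take effect after the model has already produced its answer.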

    if stop_flag:
        return "Processing stopped by the user."
    else:
        return response  # response was set by the worker thread

model_list = get_model_list()

#with gr.Blocks(theme=gr.themes.Glass()) as demo :
#with gr.Blocks() as demo :
#    demo = gr.Interface(
demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Dropdown(model_list, label="Select Model", value="mistral"),
        gr.Textbox(label="Input"),
        gr.Slider(minimum=1, maximum=2048, step=1, label="Max Tokens", value=256),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.7),
        gr.Button(value="Stop", variant="stop")
    ],
    outputs=[
        gr.Textbox(label="Output")
#        gr.Button(value="Stop", variant="stop")
    ],
    title="Ollama Demo"
)

def stop_processing():
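    """Set the global stop flag.

    Note: in this Interface layout the function is not bound to the Stop
    button, so pressing the button does not actually call it.
    """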
    global stop_flag
    stop_flag = True



if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
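

# Example (untested sketch): calling the handler directly, assuming a local
# Ollama server is running and the "mistral" model can be pulled:
#
#   reply = gradio_interface("mistral", "Say hello in one sentence.", 64, 0.7)
#   print(reply)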