MisterAI committed
Commit 69d6234 · verified · 1 Parent(s): d73bb73

Create app.py

Files changed (1): app.py +129 -0
app.py ADDED
@@ -0,0 +1,129 @@
#MisterAI/Docker_Ollama
#app.py_01
#https://huggingface.co/spaces/MisterAI/Docker_Ollama/


import logging
import requests
from pydantic import BaseModel
from langchain_community.llms import Ollama
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
import gradio as gr
import threading
import subprocess


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Cache for models that have already been loaded
loaded_models = {}

# Flag tracking the state of the "Stop" button
stop_flag = False


def get_model_list():
    url = "https://ollama.com/search"
    response = requests.get(url)

    # Check whether the request succeeded
    if response.status_code == 200:
        # Extract the model names from the HTML search page
        model_list = [model.strip() for model in response.text.split('<span x-test-search-response-title>')[1:]]
        model_list = [model.split('</span>')[0] for model in model_list]
        return model_list
    else:
        logger.error(f"Failed to fetch the model list: {response.status_code} - {response.text}")
        return []


def get_llm(model_name):
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
    return Ollama(model=model_name, callback_manager=callback_manager)


class InputData(BaseModel):
    model_name: str
    input: str
    max_tokens: int = 256
    temperature: float = 0.7


def pull_model(model_name):
    try:
        # Pull the model through the Ollama CLI
        subprocess.run(["ollama", "pull", model_name], check=True)
        logger.info(f"Model {model_name} pulled successfully.")
    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to pull model {model_name}: {e}")
        raise


def check_and_load_model(model_name):
    # Reuse the model if it has already been loaded
    if model_name in loaded_models:
        logger.info(f"Model {model_name} is already loaded.")
        return loaded_models[model_name]
    else:
        logger.info(f"Loading model {model_name}...")
        # Pull the model if necessary
        pull_model(model_name)
        llm = get_llm(model_name)
        loaded_models[model_name] = llm
        return llm


# Gradio interface
def gradio_interface(model_name, input, max_tokens, temperature, stop_button=None):
    global stop_flag
    stop_flag = False
    response = None  # Initialize response here so the worker thread can assign to it

    def worker():
        nonlocal response  # Bind to the response variable defined in the enclosing function
        llm = check_and_load_model(model_name)
        response = llm(input, max_tokens=max_tokens, temperature=temperature)

    thread = threading.Thread(target=worker)
    thread.start()
    thread.join()

    if stop_flag:
        return "Processing stopped by the user."
    else:
        return response  # response is visible here through the closure


model_list = get_model_list()

demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Dropdown(model_list, label="Select Model", value="mistral"),
        gr.Textbox(label="Input"),
        gr.Slider(minimum=1, maximum=2048, step=1, label="Max Tokens", value=256),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.7),
        gr.Button(value="Stop", variant="stop")
    ],
    outputs=[
        gr.Textbox(label="Output")
        # gr.Button(value="Stop", variant="stop")
    ],
    title="Ollama Demo",
    theme=gr.themes.Glass()
)


def stop_processing():
    global stop_flag
    stop_flag = True


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
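For a quick local check outside the Gradio UI, the helpers above can be called directly. The sketch below is hypothetical and not part of the Space: it assumes a local Ollama server is running, that the ollama CLI used by pull_model is on the PATH, and that importing app is acceptable (the import also runs the module-level model-list fetch and interface construction); the prompts are illustrative only.

# Hypothetical smoke test for app.py; assumes a local Ollama server and the `ollama` CLI.
from app import check_and_load_model, gradio_interface

# The first call pulls the model via `ollama pull` and caches the LangChain wrapper;
# later calls with the same name reuse the cached instance.
llm = check_and_load_model("mistral")
print(llm("Reply with one short sentence."))

# Same code path the Gradio UI uses, including the worker thread.
print(gradio_interface("mistral", "Name three colors.", max_tokens=64, temperature=0.2))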
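Under the hood, the LangChain Ollama wrapper sends generation requests to the local Ollama server's HTTP API. A rough, non-streaming equivalent of one such request is sketched below, assuming the server is reachable at its default address http://localhost:11434; note that Ollama's own name for the maximum-token option is num_predict.

import requests

# Direct, non-streaming request to the local Ollama HTTP API (default port 11434),
# roughly what the LangChain wrapper performs for a single generation.
payload = {
    "model": "mistral",
    "prompt": "Reply with one short sentence.",
    "stream": False,
    "options": {"temperature": 0.7, "num_predict": 256},
}
resp = requests.post("http://localhost:11434/api/generate", json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["response"])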