Al-Alcoba-Inciarte committed on
Commit 557ff8c · verified · 1 Parent(s): afc58c7

Update app.py

Files changed (1)
  1. app.py +18 -24
app.py CHANGED
@@ -1,7 +1,7 @@
 import gradio as gr
 import subprocess
-import time
 import requests
+import time
 import logging
 from langchain_community.llms import Ollama
 from langchain.callbacks.manager import CallbackManager
@@ -10,11 +10,11 @@ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Global cache to store loaded models
+# Cache for loaded models
 loaded_models = {}
 
-# Function to check if Ollama is running
 def check_ollama_running():
+    """Wait until Ollama is fully ready."""
     url = "http://127.0.0.1:11434/api/tags"
     for _ in range(10):  # Try for ~10 seconds
         try:
@@ -23,46 +23,40 @@ def check_ollama_running():
             logger.info("Ollama is running.")
             return True
         except requests.exceptions.RequestException:
-            logger.warning("Ollama is not running yet. Retrying...")
-            time.sleep(1)
+            logger.warning("Waiting for Ollama to start...")
+            time.sleep(2)
     raise RuntimeError("Ollama is not running. Please check the server.")
 
-# Function to pull a model if not already available
 def pull_model(model_name):
+    """Ensure the model is available before use."""
+    if model_name in loaded_models:
+        logger.info(f"Model {model_name} is already loaded.")
+        return
     try:
-        logger.info(f"Pulling model: {model_name}")
         subprocess.run(["ollama", "pull", model_name], check=True)
         logger.info(f"Model {model_name} pulled successfully.")
+        loaded_models[model_name] = True
     except subprocess.CalledProcessError as e:
         logger.error(f"Failed to pull model {model_name}: {e}")
         raise
 
-# Function to get an LLM instance with streaming enabled
 def get_llm(model_name):
+    """Get an LLM instance with streaming enabled."""
     callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
     return Ollama(model=model_name, base_url="http://127.0.0.1:11434", callback_manager=callback_manager)
 
-# Function to check and load a model
-def check_and_load_model(model_name):
-    if model_name in loaded_models:
-        logger.info(f"Model {model_name} is already loaded.")
-        return loaded_models[model_name]
-    pull_model(model_name)  # Ensure the model is available
-    llm = get_llm(model_name)
-    loaded_models[model_name] = llm
-    return llm
-
-# Function to handle Gradio input with streaming
 def query_model(model_name, prompt):
-    check_ollama_running()  # Ensure Ollama is running before making requests
-    llm = check_and_load_model(model_name)
+    """Generate responses from the model with streaming."""
+    check_ollama_running()  # Ensure Ollama is ready
+    pull_model(model_name)  # Make sure the model is available
+    llm = get_llm(model_name)  # Load the model
 
     response = ""
    for token in llm.stream(prompt):
         response += token
-        yield response  # Stream the response to Gradio in real-time
+        yield response  # Stream response in real-time
 
-# Define the Gradio interface
+# Define Gradio interface
 iface = gr.Interface(
     fn=query_model,
     inputs=[
@@ -76,4 +70,4 @@ iface = gr.Interface(
 )
 
 if __name__ == "__main__":
-    iface.launch(server_name="0.0.0.0", server_port=8080)
+    iface.launch(server_name="0.0.0.0", server_port=7860)
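With this change, query_model is a generator: each yield is the accumulated response so far, model pulls are cached in loaded_models, and the app serves on port 7860. A minimal sketch of driving the generator outside Gradio, assuming the file is importable as app and that "mistral" is only an example model tag available to the local Ollama server:

# Hypothetical driver sketch; "mistral" is just an example tag, and the
# Ollama server is assumed to be reachable on 127.0.0.1:11434.
from app import query_model

for partial in query_model("mistral", "Explain streaming in one sentence."):
    # Each yielded value is the full response so far, not a single token.
    print(partial)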