import gradio as gr
import subprocess
import requests
import time
import logging
from langchain_community.llms import Ollama
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Cache of models already pulled in this process (avoids repeated `ollama pull` calls)
loaded_models = {}

def check_ollama_running():
    """Wait until Ollama is fully ready."""
    url = "http://127.0.0.1:11434/api/tags"
    for _ in range(10):  # Retry for up to ~20 seconds (10 attempts, 2 s apart)
        try:
            response = requests.get(url, timeout=2)
            if response.status_code == 200:
                logger.info("Ollama is running.")
                return True
        except requests.exceptions.RequestException:
            logger.warning("Waiting for Ollama to start...")
        time.sleep(2)
    raise RuntimeError("Ollama is not running. Please check the server.")

def pull_model(model_name):
    """Ensure the model is available before use."""
    if model_name in loaded_models:
        logger.info(f"Model {model_name} has already been pulled.")
        return
    try:
        subprocess.run(["ollama", "pull", model_name], check=True)
        logger.info(f"Model {model_name} pulled successfully.")
        loaded_models[model_name] = True
    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to pull model {model_name}: {e}")
        raise

def get_llm(model_name):
    """Get an LLM instance with streaming enabled."""
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
    return Ollama(model=model_name, base_url="http://127.0.0.1:11434", callback_manager=callback_manager)

def query_model(model_name, prompt):
    """Generate responses from the model with streaming."""
    check_ollama_running()  # Ensure Ollama is ready
    pull_model(model_name)  # Make sure the model is available
    llm = get_llm(model_name)  # Load the model

    response = ""
    for token in llm.stream(prompt):
        response += token
        yield response  # Stream response in real-time

# Define Gradio interface
iface = gr.Interface(
    fn=query_model,
    inputs=[
        gr.Dropdown(["deepseek-r1:1.5b", "mistral:7b"], label="Select Model"),
        gr.Textbox(label="Enter your prompt")
    ],
    outputs="text",
    title="Ollama via LangChain & Gradio",
    description="Enter a prompt to interact with an Ollama model via LangChain; the response streams as it is generated.",
    flagging_dir="/app/flagged"
)

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)
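
# Quick sanity check (a minimal sketch, assuming a local Ollama server on
# 127.0.0.1:11434 and that the selected model tag exists): query_model is an
# ordinary generator, so it can be exercised from a Python shell without the
# Gradio UI. Each yielded value is the cumulative response so far.
#
#     for partial in query_model("mistral:7b", "Explain streaming in one sentence."):
#         print(partial)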