# --- Earlier version (no MLflow tracking), kept inside a string literal so it never executes ---
'''
import os
import gradio as gr
import requests
from pinecone import Pinecone
from langchain.prompts import PromptTemplate
from langchain.chains.llm import LLMChain
from langchain.llms.base import LLM
from typing import Optional, List, Mapping, Any
from langchain.embeddings import HuggingFaceEmbeddings

# ----------- 1. Custom LLM to call your LitServe endpoint -----------
class LitServeLLM(LLM):
    endpoint_url: str

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        payload = {"prompt": prompt}
        response = requests.post(self.endpoint_url, json=payload)
        if response.status_code == 200:
            data = response.json()
            return data.get("response", "").strip()
        else:
            raise ValueError(f"Request failed: {response.status_code} {response.text}")

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"endpoint_url": self.endpoint_url}

    @property
    def _llm_type(self) -> str:
        return "litserve_llm"

# ----------- 2. Connect to Pinecone -----------
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
pc = Pinecone(api_key=PINECONE_API_KEY)
index = pc.Index("rag-granite-index")

# ----------- 3. Load embedding model -----------
embeddings_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

# ----------- 4. Function to get top context from Pinecone -----------
def get_retrieved_context(query: str, top_k=3):
    query_embedding = embeddings_model.embed_query(query)
    results = index.query(
        namespace="rag-ns",
        vector=query_embedding,
        top_k=top_k,
        include_metadata=True
    )
    context_parts = [match['metadata']['text'] for match in results['matches']]
    return "\n".join(context_parts)

# ----------- 5. Create LLMChain with your model -----------
model = LitServeLLM(
    endpoint_url="https://8001-01k2h9d9mervcmgfn66ybkpwvq.cloudspaces.litng.ai/predict"
)

prompt = PromptTemplate(
    input_variables=["context", "question"],
    template="""
You are a smart assistant. Based on the provided context, answer the question in 1–2 lines only.
If the context has more details, summarize it concisely.
Context:
{context}
Question: {question}
Answer:
"""
)

llm_chain = LLMChain(llm=model, prompt=prompt)

# ----------- 6. Main RAG Function -----------
def rag_pipeline(question):
    try:
        retrieved_context = get_retrieved_context(question)
        response = llm_chain.invoke({
            "context": retrieved_context,
            "question": question
        })["text"].strip()
        # Only keep what's after "Answer:"
        if "Answer:" in response:
            response = response.split("Answer:", 1)[-1].strip()
        return response
    except Exception as e:
        return f"Error: {str(e)}"

# ----------- 7. Gradio UI -----------
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 RAG Chatbot (Pinecone + LitServe)")
    question_input = gr.Textbox(label="Ask your question here")
    answer_output = gr.Textbox(label="Answer")
    ask_button = gr.Button("Get Answer")
    ask_button.click(rag_pipeline, inputs=question_input, outputs=answer_output)

if __name__ == "__main__":
    demo.launch()
'''
# --- Earlier version (MLflow tracing, no evaluation), kept inside a string literal so it never executes ---
'''
import os
import gradio as gr
import requests
import mlflow
import dagshub
from pinecone import Pinecone
from langchain.prompts import PromptTemplate
from langchain.chains.llm import LLMChain
from langchain.llms.base import LLM
from typing import Optional, List, Mapping, Any
import time
from langchain_community.embeddings import HuggingFaceEmbeddings
from dotenv import load_dotenv
from datetime import datetime

# Load environment variables
load_dotenv()
pinecone_api_key = os.environ["PINECONE_API_KEY"]
mlflow_tracking_uri = os.environ["MLFLOW_TRACKING_URI"]

# ----------- DagsHub & MLflow Setup -----------
dagshub.init(
    repo_owner='prathamesh.khade20',
    repo_name='Maintenance_AI_website',
    mlflow=True
)
mlflow.set_tracking_uri(mlflow_tracking_uri)
mlflow.set_experiment("Maintenance-RAG-Chatbot")
mlflow.langchain.autolog()

# Initialize MLflow run for app configuration
with mlflow.start_run(run_name=f"App-Config-{datetime.now().strftime('%Y%m%d-%H%M%S')}") as setup_run:
    # Log environment configuration
    mlflow.log_params({
        "pinecone_index": "rag-granite-index",
        "embedding_model": "all-MiniLM-L6-v2",
        "namespace": "rag-ns",
        "top_k": 3,
        "llm_endpoint": "https://8001-01k2h9d9mervcmgfn66ybkpwvq.cloudspaces.litng.ai/predict"
    })
    # Log the prompt template as an artifact
    mlflow.log_text("""
You are a smart assistant. Based on the provided context, answer the question in 1–2 lines only.
If the context has more details, summarize it concisely.
Context:
{context}
Question: {question}
Answer:
""", "artifacts/prompt_template.txt")

# ----------- 1. Custom LLM for LitServe endpoint -----------
class LitServeLLM(LLM):
    endpoint_url: str

    @mlflow.trace
    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        payload = {"prompt": prompt}
        with mlflow.start_span("lit_serve_request"):
            start_time = time.time()
            response = requests.post(self.endpoint_url, json=payload)
            latency = time.time() - start_time
            mlflow.log_metric("lit_serve_latency", latency)
            if response.status_code == 200:
                data = response.json()
                mlflow.log_metric("response_tokens", len(data.get("response", "").split()))
                return data.get("response", "").strip()
            else:
                mlflow.log_metric("request_errors", 1)
                error_info = {
                    "status_code": response.status_code,
                    "error": response.text,
                    "timestamp": datetime.now().isoformat()
                }
                mlflow.log_dict(error_info, "artifacts/error_log.json")
                raise ValueError(f"Request failed: {response.status_code}")

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"endpoint_url": self.endpoint_url}

    @property
    def _llm_type(self) -> str:
        return "litserve_llm"

# ----------- 2. Pinecone Connection -----------
@mlflow.trace
def init_pinecone():
    PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
    pc = Pinecone(api_key=PINECONE_API_KEY)
    return pc.Index("rag-granite-index")

index = init_pinecone()

# ----------- 3. Embedding Model -----------
embeddings_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

# ----------- 4. Context Retrieval with Tracing -----------
@mlflow.trace
def get_retrieved_context(query: str, top_k=3):
    """Retrieve context from Pinecone with performance tracing"""
    with mlflow.start_span("embedding_generation"):
        start_time = time.time()
        query_embedding = embeddings_model.embed_query(query)
        mlflow.log_metric("embedding_latency", time.time() - start_time)

    with mlflow.start_span("pinecone_query"):
        start_time = time.time()
        results = index.query(
            namespace="rag-ns",
            vector=query_embedding,
            top_k=top_k,
            include_metadata=True
        )
        mlflow.log_metric("pinecone_latency", time.time() - start_time)
        mlflow.log_metric("retrieved_chunks", len(results['matches']))

    context_parts = [match['metadata']['text'] for match in results['matches']]
    return "\n".join(context_parts)

# ----------- 5. LLM Chain Setup -----------
model = LitServeLLM(
    endpoint_url="https://8001-01k2h9d9mervcmgfn66ybkpwvq.cloudspaces.litng.ai/predict"
)

prompt = PromptTemplate(
    input_variables=["context", "question"],
    template="""
You are a smart assistant. Based on the provided context, answer the question in 1–2 lines only.
If the context has more details, summarize it concisely.
Context:
{context}
Question: {question}
Answer:
"""
)

llm_chain = LLMChain(llm=model, prompt=prompt)

# ----------- 6. RAG Pipeline with Full Tracing -----------
@mlflow.trace
def rag_pipeline(question):
    """End-to-end RAG pipeline with MLflow tracing"""
    try:
        # Start a new nested run for each query
        with mlflow.start_run(run_name=f"Query-{datetime.now().strftime('%H%M%S')}", nested=True):
            mlflow.log_param("user_question", question)

            # Retrieve context
            retrieved_context = get_retrieved_context(question)
            mlflow.log_text(retrieved_context, "artifacts/retrieved_context.txt")

            # Generate response
            start_time = time.time()
            response = llm_chain.invoke({
                "context": retrieved_context,
                "question": question
            })["text"].strip()

            # Clean response
            if "Answer:" in response:
                response = response.split("Answer:", 1)[-1].strip()

            # Log metrics
            mlflow.log_metric("response_latency", time.time() - start_time)
            mlflow.log_metric("response_length", len(response))
            mlflow.log_text(response, "artifacts/response.txt")

            return response
    except Exception as e:
        mlflow.log_metric("pipeline_errors", 1)
        error_info = {
            "error": str(e),
            "question": question,
            "timestamp": datetime.now().isoformat()
        }
        mlflow.log_dict(error_info, "artifacts/pipeline_errors.json")
        return f"Error: {str(e)}"

# ----------- 7. Gradio UI with Enhanced Tracking -----------
with gr.Blocks() as demo:
    gr.Markdown("# 🛠 Maintenance AI Assistant")

    # Track additional UI metrics
    usage_counter = gr.State(value=0)
    session_start = gr.State(value=datetime.now().isoformat())

    question_input = gr.Textbox(label="Ask your maintenance question")
    answer_output = gr.Textbox(label="AI Response")
    ask_button = gr.Button("Get Answer")
    feedback = gr.Radio(["Helpful", "Not Helpful"], label="Was this response helpful?")

    def track_usage(question, count, session_start, feedback=None):
        """Wrapper to track usage metrics with feedback"""
        count += 1
        # Start tracking context
        with mlflow.start_run(run_name=f"User-Interaction-{count}", nested=True):
            mlflow.log_param("question", question)
            mlflow.log_param("session_start", session_start)

            # Get response
            response = rag_pipeline(question)

            # Log feedback if provided
            if feedback:
                mlflow.log_param("user_feedback", feedback)
                mlflow.log_metric("helpful_responses", 1 if feedback == "Helpful" else 0)

            # Update metrics
            mlflow.log_metric("total_queries", count)

        return response, count, session_start

    ask_button.click(
        track_usage,
        inputs=[question_input, usage_counter, session_start],
        outputs=[answer_output, usage_counter, session_start]
    )

    feedback.change(
        track_usage,
        inputs=[question_input, usage_counter, session_start, feedback],
        outputs=[answer_output, usage_counter, session_start]
    )

if __name__ == "__main__":
    # Log deployment information
    with mlflow.start_run(run_name="Deployment-Info"):
        mlflow.log_params({
            "app_version": "1.0.0",
            "deployment_platform": "Lightning AI",
            "deployment_time": datetime.now().isoformat(),
            "code_version": os.getenv("GIT_COMMIT", "dev")
        })

    # Start Gradio app
    demo.launch()
'''
import torch
from sacrebleu import corpus_bleu
from rouge_score import rouge_scorer
from bert_score import score
from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline, AutoTokenizer
import re
from mauve import compute_mauve
import os
import gradio as gr
import requests
import mlflow
import dagshub
from pinecone import Pinecone
from langchain.prompts import PromptTemplate
from langchain.chains.llm import LLMChain
from langchain.llms.base import LLM
from typing import Optional, List, Mapping, Any
import time
from langchain_community.embeddings import HuggingFaceEmbeddings
from dotenv import load_dotenv
from datetime import datetime

# Load environment variables
load_dotenv()
pinecone_api_key = os.environ["PINECONE_API_KEY"]
mlflow_tracking_uri = os.environ["MLFLOW_TRACKING_URI"]
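# A .env file in the project root is assumed to supply these secrets (placeholder values shown,
# not real credentials); on the deployed Space the same names would be set as environment secrets:
#
#   PINECONE_API_KEY=pc-xxxxxxxxxxxxxxxxxxxx
#   MLFLOW_TRACKING_URI=https://dagshub.com/<owner>/<repo>.mlflow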
# ----------- DagsHub & MLflow Setup -----------
dagshub.init(
    repo_owner='prathamesh.khade20',
    repo_name='Maintenance_AI_website',
    mlflow=True
)
mlflow.set_tracking_uri(mlflow_tracking_uri)
mlflow.set_experiment("Maintenance-RAG-Chatbot")
mlflow.langchain.autolog()
# ----------- RAG Evaluator Class -----------
class RAGEvaluator:
    def __init__(self):
        self.gpt2_model, self.gpt2_tokenizer = self.load_gpt2_model()
        self.bias_pipeline = pipeline("zero-shot-classification", model="Hate-speech-CNERG/dehatebert-mono-english")
        # Tokenizer used for token-level metrics (diversity, METEOR approximation)
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    def load_gpt2_model(self):
        model = GPT2LMHeadModel.from_pretrained('gpt2')
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        return model, tokenizer

    def evaluate_bleu_rouge(self, candidates, references):
        bleu_score = corpus_bleu(candidates, [references]).score
        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
        rouge_scores = [scorer.score(ref, cand) for ref, cand in zip(references, candidates)]
        rouge1 = sum(s['rouge1'].fmeasure for s in rouge_scores) / len(rouge_scores)
        rouge2 = sum(s['rouge2'].fmeasure for s in rouge_scores) / len(rouge_scores)
        rougeL = sum(s['rougeL'].fmeasure for s in rouge_scores) / len(rouge_scores)
        return bleu_score, rouge1, rouge2, rougeL

    def evaluate_bert_score(self, candidates, references):
        P, R, F1 = score(candidates, references, lang="en", model_type='bert-base-multilingual-cased')
        return P.mean().item(), R.mean().item(), F1.mean().item()

    def evaluate_perplexity(self, text):
        encodings = self.gpt2_tokenizer(text, return_tensors='pt')
        max_length = self.gpt2_model.config.n_positions
        stride = 512
        lls = []
        for i in range(0, encodings.input_ids.size(1), stride):
            begin_loc = max(i + stride - max_length, 0)
            end_loc = min(i + stride, encodings.input_ids.size(1))
            trg_len = end_loc - i
            input_ids = encodings.input_ids[:, begin_loc:end_loc]
            target_ids = input_ids.clone()
            target_ids[:, :-trg_len] = -100
            with torch.no_grad():
                outputs = self.gpt2_model(input_ids, labels=target_ids)
                log_likelihood = outputs[0] * trg_len
            lls.append(log_likelihood)
        ppl = torch.exp(torch.stack(lls).sum() / end_loc)
        return ppl.item()

    def evaluate_diversity(self, texts):
        # Use the Hugging Face tokenizer instead of NLTK
        all_tokens = []
        for text in texts:
            tokens = self.tokenizer.tokenize(text)
            all_tokens.extend(tokens)
        # Create bigrams manually
        unique_bigrams = set()
        for i in range(len(all_tokens) - 1):
            unique_bigrams.add((all_tokens[i], all_tokens[i + 1]))
        diversity_score = len(unique_bigrams) / len(all_tokens) if all_tokens else 0
        return diversity_score

    def evaluate_racial_bias(self, text):
        results = self.bias_pipeline([text], candidate_labels=["hate speech", "not hate speech"])
        bias_score = results[0]['scores'][results[0]['labels'].index('hate speech')]
        return bias_score

    def evaluate_meteor(self, candidates, references):
        # Simple approximation of METEOR without NLTK.
        # This is a simplified version - consider using an external library for full METEOR.
        meteor_scores = []
        for ref, cand in zip(references, candidates):
            ref_tokens = self.tokenizer.tokenize(ref)
            cand_tokens = self.tokenizer.tokenize(cand)
            # Calculate precision and recall on the shared token set
            common_tokens = set(ref_tokens) & set(cand_tokens)
            precision = len(common_tokens) / len(cand_tokens) if cand_tokens else 0
            recall = len(common_tokens) / len(ref_tokens) if ref_tokens else 0
            # Recall-weighted F-measure with alpha=0.9 (METEOR default)
            if precision + recall == 0:
                f_score = 0
            else:
                f_score = (10 * precision * recall) / (9 * precision + recall)
            meteor_scores.append(f_score)
        return sum(meteor_scores) / len(meteor_scores) if meteor_scores else 0

    def evaluate_chrf(self, candidates, references):
        # Simple character n-gram F-score approximation
        chrf_scores = []
        for ref, cand in zip(references, candidates):
            # Character 6-grams
            ref_chars = list(ref)
            cand_chars = list(cand)
            ref_ngrams = set()
            cand_ngrams = set()
            for i in range(len(ref_chars) - 5):
                ref_ngrams.add(tuple(ref_chars[i:i + 6]))
            for i in range(len(cand_chars) - 5):
                cand_ngrams.add(tuple(cand_chars[i:i + 6]))
            common_ngrams = ref_ngrams & cand_ngrams
            precision = len(common_ngrams) / len(cand_ngrams) if cand_ngrams else 0
            recall = len(common_ngrams) / len(ref_ngrams) if ref_ngrams else 0
            if precision + recall == 0:
                chrf_score = 0
            else:
                chrf_score = 2 * precision * recall / (precision + recall)
            chrf_scores.append(chrf_score)
        return sum(chrf_scores) / len(chrf_scores) if chrf_scores else 0

    def evaluate_readability(self, text):
        # Simple readability metrics without textstat
        words = re.findall(r'\b\w+\b', text.lower())
        sentences = re.split(r'[.!?]+', text)
        num_words = len(words)
        num_sentences = len([s for s in sentences if s.strip()])
        # Average word length
        avg_word_length = sum(len(word) for word in words) / num_words if num_words else 0
        # Words per sentence
        words_per_sentence = num_words / num_sentences if num_sentences else 0
        # Simplified Flesch Reading Ease approximation (word length stands in for syllable count)
        flesch_ease = 206.835 - (1.015 * words_per_sentence) - (84.6 * avg_word_length)
        # Simplified Flesch-Kincaid Grade Level approximation
        flesch_grade = (0.39 * words_per_sentence) + (11.8 * avg_word_length) - 15.59
        return flesch_ease, flesch_grade

    def evaluate_mauve(self, reference_texts, generated_texts):
        out = compute_mauve(
            p_text=reference_texts,
            q_text=generated_texts,
            device_id=0,
            max_text_length=1024,
            verbose=False
        )
        return out.mauve

    def evaluate_all(self, question, response, reference):
        candidates = [response]
        references = [reference]
        bleu, rouge1, rouge2, rougeL = self.evaluate_bleu_rouge(candidates, references)
        bert_p, bert_r, bert_f1 = self.evaluate_bert_score(candidates, references)
        perplexity = self.evaluate_perplexity(response)
        diversity = self.evaluate_diversity(candidates)
        racial_bias = self.evaluate_racial_bias(response)
        meteor = self.evaluate_meteor(candidates, references)
        chrf = self.evaluate_chrf(candidates, references)
        flesch_ease, flesch_grade = self.evaluate_readability(response)
        # MAUVE requires multiple samples, so it is skipped for single-response evaluation
        mauve_score = self.evaluate_mauve(references, candidates) if len(references) > 1 else 0.0
        return {
            "BLEU": bleu,
            "ROUGE-1": rouge1,
            "ROUGE-2": rouge2,
            "ROUGE-L": rougeL,
            "BERT_Precision": bert_p,
            "BERT_Recall": bert_r,
            "BERT_F1": bert_f1,
            "Perplexity": perplexity,
            "Diversity": diversity,
            "Racial_Bias": racial_bias,
            "MAUVE": mauve_score,
            "METEOR": meteor,
            "CHRF": chrf,
            "Flesch_Reading_Ease": flesch_ease,
            "Flesch_Kincaid_Grade": flesch_grade,
        }

# Initialize the evaluator
evaluator = RAGEvaluator()
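# Minimal usage sketch for the evaluator (hypothetical strings, shown only to illustrate the
# expected inputs of evaluate_all: the question, the generated answer, and a reference text):
#
#   metrics = evaluator.evaluate_all(
#       question="How often should the pump bearings be greased?",
#       response="Grease the pump bearings every 500 operating hours.",
#       reference="Maintenance schedule: grease pump bearings every 500 hours of operation.",
#   )
#   print(metrics["BLEU"], metrics["ROUGE-L"], metrics["BERT_F1"])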
# Initialize MLflow run for app configuration
with mlflow.start_run(run_name=f"App-Config-{datetime.now().strftime('%Y%m%d-%H%M%S')}") as setup_run:
    # Log environment configuration
    mlflow.log_params({
        "pinecone_index": "rag-granite-index",
        "embedding_model": "all-MiniLM-L6-v2",
        "namespace": "rag-ns",
        "top_k": 3,
        "llm_endpoint": "https://8001-01k2h9d9mervcmgfn66ybkpwvq.cloudspaces.litng.ai/predict"
    })

    # Log prompt template
    mlflow.log_text("""
You are a smart assistant. Based on the provided context, answer the question in 1–2 lines only.
If the context has more details, summarize it concisely.
Context:
{context}
Question: {question}
Answer:
""", "artifacts/prompt_template.txt")
# ----------- 1. Custom LLM for LitServe endpoint -----------
class LitServeLLM(LLM):
    endpoint_url: str

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        payload = {"prompt": prompt}
        with mlflow.start_span("lit_serve_request"):
            start_time = time.time()
            response = requests.post(self.endpoint_url, json=payload)
            latency = time.time() - start_time
            mlflow.log_metric("lit_serve_latency", latency)
            if response.status_code == 200:
                data = response.json()
                mlflow.log_metric("response_tokens", len(data.get("response", "").split()))
                return data.get("response", "").strip()
            else:
                mlflow.log_metric("request_errors", 1)
                error_info = {
                    "status_code": response.status_code,
                    "error": response.text,
                    "timestamp": datetime.now().isoformat()
                }
                mlflow.log_dict(error_info, "artifacts/error_log.json")
                raise ValueError(f"Request failed: {response.status_code}")

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"endpoint_url": self.endpoint_url}

    @property
    def _llm_type(self) -> str:
        return "litserve_llm"
# ----------- 2. Pinecone Connection -----------
def init_pinecone():
    PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
    pc = Pinecone(api_key=PINECONE_API_KEY)
    return pc.Index("rag-granite-index")

index = init_pinecone()
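# Assumed layout of the records already upserted into "rag-granite-index" (namespace "rag-ns"):
# each vector stores its source chunk under metadata["text"], which get_retrieved_context()
# below reads back. A hypothetical upsert matching that layout:
#
#   index.upsert(
#       vectors=[{
#           "id": "doc-001-chunk-0",
#           "values": [...],  # 384-dim embedding from all-MiniLM-L6-v2
#           "metadata": {"text": "Grease pump bearings every 500 hours."},
#       }],
#       namespace="rag-ns",
#   )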
# ----------- 3. Embedding Model -----------
embeddings_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

# ----------- 4. Context Retrieval with Tracing -----------
def get_retrieved_context(query: str, top_k=3):
    """Retrieve context from Pinecone with performance tracing"""
    with mlflow.start_span("embedding_generation"):
        start_time = time.time()
        query_embedding = embeddings_model.embed_query(query)
        mlflow.log_metric("embedding_latency", time.time() - start_time)

    with mlflow.start_span("pinecone_query"):
        start_time = time.time()
        results = index.query(
            namespace="rag-ns",
            vector=query_embedding,
            top_k=top_k,
            include_metadata=True
        )
        mlflow.log_metric("pinecone_latency", time.time() - start_time)
        mlflow.log_metric("retrieved_chunks", len(results['matches']))

    context_parts = [match['metadata']['text'] for match in results['matches']]
    return "\n".join(context_parts)
# ----------- 5. LLM Chain Setup -----------
model = LitServeLLM(
    endpoint_url="https://8001-01k2h9d9mervcmgfn66ybkpwvq.cloudspaces.litng.ai/predict"
)

prompt = PromptTemplate(
    input_variables=["context", "question"],
    template="""
You are a smart assistant. Based on the provided context, answer the question in 1–2 lines only.
If the context has more details, summarize it concisely.
Context:
{context}
Question: {question}
Answer:
"""
)

llm_chain = LLMChain(llm=model, prompt=prompt)
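# The chain can be exercised without Pinecone by passing a hand-written context (hypothetical
# strings; LLMChain.invoke returns a dict whose "text" key holds the completion, as used below):
#
#   out = llm_chain.invoke({
#       "context": "Grease pump bearings every 500 operating hours.",
#       "question": "How often should the pump bearings be greased?",
#   })
#   print(out["text"])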
# ----------- 6. RAG Pipeline with Full Tracing and Evaluation -----------
def rag_pipeline(question):
    """End-to-end RAG pipeline with MLflow tracing and evaluation"""
    try:
        # Start a new nested run for each query
        with mlflow.start_run(run_name=f"Query-{datetime.now().strftime('%H%M%S')}", nested=True):
            mlflow.log_param("user_question", question)

            # Retrieve context
            retrieved_context = get_retrieved_context(question)
            mlflow.log_text(retrieved_context, "artifacts/retrieved_context.txt")

            # Generate response
            start_time = time.time()
            response = llm_chain.invoke({
                "context": retrieved_context,
                "question": question
            })["text"].strip()

            # Clean response
            if "Answer:" in response:
                response = response.split("Answer:", 1)[-1].strip()

            # Log metrics
            mlflow.log_metric("response_latency", time.time() - start_time)
            mlflow.log_metric("response_length", len(response))
            mlflow.log_text(response, "artifacts/response.txt")

            # Evaluate the response against the retrieved context
            evaluation_metrics = evaluator.evaluate_all(
                question=question,
                response=response,
                reference=retrieved_context
            )

            # Log evaluation metrics to MLflow
            for metric_name, metric_value in evaluation_metrics.items():
                mlflow.log_metric(metric_name, metric_value)

            return response
    except Exception as e:
        mlflow.log_metric("pipeline_errors", 1)
        error_info = {
            "error": str(e),
            "question": question,
            "timestamp": datetime.now().isoformat()
        }
        mlflow.log_dict(error_info, "artifacts/pipeline_errors.json")
        return f"Error: {str(e)}"
# ----------- 7. Gradio UI with Enhanced Tracking -----------
with gr.Blocks() as demo:
    gr.Markdown("# 🛠 Maintenance AI Assistant")

    # Track additional UI metrics
    usage_counter = gr.State(value=0)
    session_start = gr.State(value=datetime.now().isoformat())

    question_input = gr.Textbox(label="Ask your maintenance question")
    answer_output = gr.Textbox(label="AI Response")
    ask_button = gr.Button("Get Answer")
    feedback = gr.Radio(["Helpful", "Not Helpful"], label="Was this response helpful?")

    def track_usage(question, count, session_start, feedback=None):
        """Wrapper to track usage metrics with feedback"""
        count += 1
        # Start tracking context
        with mlflow.start_run(run_name=f"User-Interaction-{count}", nested=True):
            mlflow.log_param("question", question)
            mlflow.log_param("session_start", session_start)

            # Get response
            response = rag_pipeline(question)

            # Log feedback if provided
            if feedback:
                mlflow.log_param("user_feedback", feedback)
                mlflow.log_metric("helpful_responses", 1 if feedback == "Helpful" else 0)

            # Update metrics
            mlflow.log_metric("total_queries", count)

        return response, count, session_start

    ask_button.click(
        track_usage,
        inputs=[question_input, usage_counter, session_start],
        outputs=[answer_output, usage_counter, session_start]
    )

    feedback.change(
        lambda feedback, question, count, session_start: track_usage(question, count, session_start, feedback),
        inputs=[feedback, question_input, usage_counter, session_start],
        outputs=[answer_output, usage_counter, session_start]
    )
if __name__ == "__main__":
    # Log deployment information
    with mlflow.start_run(run_name="Deployment-Info"):
        mlflow.log_params({
            "app_version": "1.0.0",
            "deployment_platform": "Lightning AI",
            "deployment_time": datetime.now().isoformat(),
            "code_version": os.getenv("GIT_COMMIT", "dev")
        })

    # Start Gradio app
    demo.launch()