karths committed
Commit 1a586b9 · verified · 1 Parent(s): ee0bb3d

Create app.py

Files changed (1)
  1. app.py +282 -0
app.py ADDED
@@ -0,0 +1,282 @@
+import gradio as gr
+import os
+import torch
+import numpy as np
+import random
+from huggingface_hub import login, HfFolder
+from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM, TextIteratorStreamer
+from scipy.special import softmax
+import logging
+import spaces
+from threading import Thread
+from collections.abc import Iterator
+import csv
+
+# Increase CSV field size limit
+csv.field_size_limit(1000000)
+
+# Setup logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
+
+# Set a seed for reproducibility
+seed = 42
+np.random.seed(seed)
+random.seed(seed)
+torch.manual_seed(seed)
+if torch.cuda.is_available():
+    torch.cuda.manual_seed_all(seed)
+
+# Login to Hugging Face
+token = os.getenv("hf_token")
+HfFolder.save_token(token)
+login(token)
+
+model_paths = [
+    'karths/binary_classification_train_port',
+    'karths/binary_classification_train_perf',
+    "karths/binary_classification_train_main",
+    "karths/binary_classification_train_secu",
+    "karths/binary_classification_train_reli",
+    "karths/binary_classification_train_usab",
+    "karths/binary_classification_train_comp"
+]
+
+quality_mapping = {
+    'binary_classification_train_port': 'Portability',
+    'binary_classification_train_main': 'Maintainability',
+    'binary_classification_train_secu': 'Security',
+    'binary_classification_train_reli': 'Reliability',
+    'binary_classification_train_usab': 'Usability',
+    'binary_classification_train_perf': 'Performance',
+    'binary_classification_train_comp': 'Compatibility'
+}
+
+# Pre-load models and tokenizer for quality prediction
+tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
+models = {path: AutoModelForSequenceClassification.from_pretrained(path) for path in model_paths}
+
+def get_quality_name(model_name):
+    return quality_mapping.get(model_name.split('/')[-1], "Unknown Quality")
+
+
+def model_prediction(model, text, device):
+    model.to(device)
+    model.eval()
+    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+    with torch.no_grad():
+        outputs = model(**inputs)
+        logits = outputs.logits
+    probs = softmax(logits.cpu().numpy(), axis=1)
+    avg_prob = np.mean(probs[:, 1])
+    model.to("cpu")
+    return avg_prob
+
+# --- Llama 3.2 1B Model Setup ---
+LLAMA_MAX_MAX_NEW_TOKENS = 512
+LLAMA_DEFAULT_MAX_NEW_TOKENS = 512
+LLAMA_MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "1024"))
+llama_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+llama_model_id = "meta-llama/Llama-3.2-1B-Instruct"
+llama_tokenizer = AutoTokenizer.from_pretrained(llama_model_id)
+llama_model = AutoModelForCausalLM.from_pretrained(
+    llama_model_id,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+)
+llama_model.eval()
+
+if llama_tokenizer.pad_token is None:
+    llama_tokenizer.pad_token = llama_tokenizer.eos_token
+
+def llama_generate(
+    message: str,
+    max_new_tokens: int = LLAMA_DEFAULT_MAX_NEW_TOKENS,
+    temperature: float = 0.3,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
+) -> str:
+
+    inputs = llama_tokenizer(message, return_tensors="pt", padding=True, truncation=True, max_length=LLAMA_MAX_INPUT_TOKEN_LENGTH).to(llama_model.device)
+
+    if inputs.input_ids.shape[1] > LLAMA_MAX_INPUT_TOKEN_LENGTH:
+        inputs.input_ids = inputs.input_ids[:, -LLAMA_MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {LLAMA_MAX_INPUT_TOKEN_LENGTH} tokens.")
+
+    with torch.no_grad():
+        generate_ids = llama_model.generate(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            top_p=top_p,
+            top_k=top_k,
+            temperature=temperature,
+            num_beams=1,
+            repetition_penalty=repetition_penalty,
+            pad_token_id=llama_tokenizer.pad_token_id,
+            eos_token_id=llama_tokenizer.eos_token_id,
+        )
+    output_text = llama_tokenizer.decode(generate_ids[0], skip_special_tokens=True)
+    torch.cuda.empty_cache()
+    return output_text
+
+
+def generate_explanation(issue_text, top_quality):
+    """Generates an explanation for the *single* top quality above threshold."""
+    if not top_quality:
+        return "<div style='color: red;'>No explanation available as no quality tags met the threshold.</div>"
+
+    quality_name = top_quality[0][0]  # Get the name of the top quality
+
+    prompt = f"""
+    Given the following issue description:
+    ---
+    {issue_text}
+    ---
+    Explain why this issue might be classified as a **{quality_name}** issue. Provide a concise explanation, relating it back to the issue description. Keep the explanation short and concise. Do not repeat the prompt or include any preamble in your response - just provide the explanation directly.
+    """
+    try:
+        explanation = llama_generate(prompt)
+        # Extract only the model's explanation, removing any prompt repetition
+        # This typically removes any preamble like "Here's why this is a [quality] issue:"
+        cleaned_explanation = explanation.split("---")[-1].strip()
+        if cleaned_explanation.lower().startswith(quality_name.lower()):
+            cleaned_explanation = cleaned_explanation[len(quality_name):].strip()
+        if cleaned_explanation.startswith(":"):
+            cleaned_explanation = cleaned_explanation[1:].strip()
+
+        # Format for better readability
+        formatted_explanation = f"<div class='explanation-box'><p><b>Why this is a {quality_name} issue:</b></p><p>{cleaned_explanation}</p></div>"
+        return formatted_explanation
+    except Exception as e:
+        logging.error(f"Error during Llama generation: {e}")
+        return "<div style='color: red;'>An error occurred while generating the explanation.</div>"
+
+
+# @spaces.GPU(duration=60)
+def main_interface(text):
+    if not text.strip():
+        return "<div style='color: red;'>No text provided. Please enter a valid issue description.</div>", "", ""
+
+    if len(text) < 30:
+        return "<div style='color: red;'>Text is less than 30 characters.</div>", "", ""
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    results = []
+    for model_path, model in models.items():
+        quality_name = get_quality_name(model_path)
+        avg_prob = model_prediction(model, text, device)
+        if avg_prob >= 0.95:  # Keep *all* results above the threshold
+            results.append((quality_name, avg_prob))
+        logging.info(f"Model: {model_path}, Quality: {quality_name}, Average Probability: {avg_prob:.3f}")
+
+    if not results:
+        return "<div style='color: red;'>No recommendation. Prediction probability is below the threshold.</div>", "", ""
+
+    # Sort and get the top result (if any meet the threshold)
+    top_result = sorted(results, key=lambda x: x[1], reverse=True)
+    if top_result:
+        top_quality = top_result[:1]  # Select only the top result
+        output_html = render_html_output(top_quality)
+        explanation = generate_explanation(text, top_quality)
+    else:  # Handle case no predictions >= 0.95
+        output_html = "<div style='color: red;'>No quality tag met the prediction probability threshold (>= 0.95).</div>"
+        explanation = ""
+
+
+    return output_html, "", explanation
+
+def render_html_output(top_qualities):
+    # Simplified to show only the top prediction
+    styles = """
+    <style>
+        .quality-container {
+            font-family: Arial, sans-serif;
+            text-align: center;
+            margin-top: 20px;
+        }
+        .quality-label, .ranking {
+            display: inline-block;
+            padding: 0.5em 1em;
+            font-size: 18px;
+            font-weight: bold;
+            color: white;
+            background-color: #007bff;
+            border-radius: 0.5rem;
+            margin-right: 10px;
+            box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
+        }
+    </style>
+    """
+    if not top_qualities:  # Handle empty case
+        return styles + "<div class='quality-container'>No Top Prediction</div>"
+
+    quality, _ = top_qualities[0]  # We know there is only one
+    html_content = f"""
+    <div class="quality-container">
+        <span class="ranking">Top Prediction</span>
+        <span class="quality-label">{quality}</span>
+    </div>
+    """
+    return styles + html_content
+
+example_texts = [
+    ["The algorithm does not accurately distinguish between the positive and negative classes during edge cases.\n\nEnvironment: Production\nReproduction: Run the classifier on the test dataset with known edge cases."],
+    ["The regression tests do not cover scenarios involving concurrent user sessions.\n\nEnvironment: Test automation suite\nReproduction: Update the test scripts to include tests for concurrent sessions."],
+    ["There is frequent miscommunication between the development and QA teams regarding feature specifications.\n\nEnvironment: Inter-team meetings\nReproduction: Audit recent communication logs and meeting notes between the teams."],
+    ["The service-oriented architecture does not effectively isolate failures, leading to cascading failures across services.\n\nEnvironment: Microservices architecture\nReproduction: Simulate a service failure and observe the impact on other services."]
+]
+# Improved CSS for better layout and appearance
+css = """
+.quality-container {
+    font-family: Arial, sans-serif;
+    text-align: center;
+    margin-top: 20px;
+    padding: 10px;
+    border: 1px solid #ddd;
+    border-radius: 8px;
+    background-color: #f9f9f9;
+}
+.quality-label, .ranking {
+    display: inline-block;
+    padding: 0.5em 1em;
+    font-size: 18px;
+    font-weight: bold;
+    color: white;
+    background-color: #007bff;
+    border-radius: 0.5rem;
+    margin-right: 10px;
+    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
+}
+.explanation-box {
+    border: 1px solid #ccc;
+    padding: 15px;
+    margin-top: 15px;
+    border-radius: 8px;
+    background-color: #fff;
+    box-shadow: 0 1px 3px rgba(0,0,0,0.1);
+    line-height: 1.5;
+}
+.explanation-box p {
+    margin: 8px 0;
+}
+.explanation-box b {
+    color: #007bff;
+}
+"""
+interface = gr.Interface(
+    fn=main_interface,
+    inputs=gr.Textbox(lines=7, label="Issue Description", placeholder="Enter your issue text here"),
+    outputs=[
+        gr.HTML(label="Prediction Output"),
+        gr.Textbox(label="Predictions", visible=False),
+        gr.Markdown(label="Explanation")
+    ],
+    title="QualityTagger",
+    description="This tool classifies text into different quality domains such as Security, Usability, Maintainability, Reliability, etc., and provides explanations.",
+    examples=example_texts,
+    css=css  # Apply the CSS
+)
+interface.launch(share=True)
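
For quick verification outside the Gradio UI, here is a minimal sketch of the scoring step that model_prediction() and main_interface() perform for a single quality model. It assumes the karths/binary_classification_train_perf checkpoint is downloadable and that torch, transformers, scipy, and numpy are installed; the issue text below is illustrative only, while the distilroberta-base tokenizer, the mean positive-class probability, and the 0.95 cutoff mirror the file above.

# sketch: reproduce the per-quality scoring from app.py for one checkpoint
import numpy as np
import torch
from scipy.special import softmax
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
model = AutoModelForSequenceClassification.from_pretrained(
    "karths/binary_classification_train_perf"  # Performance classifier from model_paths
)
model.eval()

# hypothetical issue text, at least 30 characters as main_interface() requires
text = "The login endpoint takes more than ten seconds to respond under moderate load."

inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
with torch.no_grad():
    logits = model(**inputs).logits
avg_prob = float(np.mean(softmax(logits.numpy(), axis=1)[:, 1]))

# app.py only surfaces a quality tag when this probability reaches 0.95
print(f"Performance probability: {avg_prob:.3f}")

app.py repeats this scoring for all seven checkpoints, keeps only qualities at or above the 0.95 threshold, and asks the Llama model to explain the single top-ranked tag.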