mkoot007 committed on
Commit
be16c65
·
1 Parent(s): 6b717c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -17
app.py CHANGED
@@ -1,16 +1,14 @@
1
  import pandas as pd
2
  import streamlit as st
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
- import re
5
  import torch
6
 
7
  # Load the pre-trained model and tokenizer
8
- tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
9
- model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
10
 
11
- def analyze_text(text, confidence_threshold=0.6):
12
  # Preprocess the text
13
- text = re.sub(r"[^\w\s]", "", text)
14
  text = text.lower()
15
 
16
  # Encode the text
@@ -19,14 +17,12 @@ def analyze_text(text, confidence_threshold=0.6):
19
  # Classify the text
20
  with torch.no_grad():
21
  output = model(**encoded_text)
22
- logits = output.logits
23
- predictions = logits.argmax(-1).item()
24
- confidence = torch.softmax(logits, dim=1)[0][predictions].item()
25
 
26
- if confidence > confidence_threshold:
27
- if predictions == 0:
28
- return "Job Interview Related"
29
- return "Not Job Interview Related"
30
 
31
  def count_job_related_messages(data):
32
  job_related_count = 0
@@ -34,7 +30,7 @@ def count_job_related_messages(data):
34
 
35
  for message in data["message"]:
36
  result = analyze_text(message)
37
- if result == "Job Interview Related":
38
  job_related_count += 1
39
  else:
40
  not_job_related_count += 1
@@ -42,7 +38,7 @@ def count_job_related_messages(data):
42
  return job_related_count, not_job_related_count
43
 
44
  # Streamlit application
45
- st.title("Job Interview Message Analyzer")
46
 
47
  uploaded_file = st.file_uploader("Upload CSV file")
48
  user_input = st.text_input("Enter text")
@@ -57,14 +53,14 @@ if uploaded_file:
57
  result = analyze_text(message)
58
  results.append(result)
59
 
60
- data["Job Interview Related"] = results
61
 
62
  # Count job-related messages
63
  job_related_count, not_job_related_count = count_job_related_messages(data)
64
 
65
  st.dataframe(data)
66
- st.write(f"Job Interview Related Messages: {job_related_count}")
67
- st.write(f"Not Job Interview Related Messages: {not_job_related_count}")
68
  elif user_input:
69
  # Analyze user-input text
70
  result = analyze_text(user_input)
 
1
  import pandas as pd
2
  import streamlit as st
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
4
  import torch
5
 
6
  # Load the pre-trained model and tokenizer
7
+ tokenizer = AutoTokenizer.from_pretrained("textattack/roberta-base-imdb")
8
+ model = AutoModelForSequenceClassification.from_pretrained("textattack/roberta-base-imdb")
9
 
10
+ def analyze_text(text):
11
  # Preprocess the text
 
12
  text = text.lower()
13
 
14
  # Encode the text
 
17
  # Classify the text
18
  with torch.no_grad():
19
  output = model(**encoded_text)
20
+ predictions = output.logits.argmax(-1).item()
 
 
21
 
22
+ if predictions == 1: # For IMDb sentiment analysis, 1 indicates positive sentiment
23
+ return "Job Related"
24
+ else:
25
+ return "Not Job Related"
26
 
27
  def count_job_related_messages(data):
28
  job_related_count = 0
 
30
 
31
  for message in data["message"]:
32
  result = analyze_text(message)
33
+ if result == "Job Related":
34
  job_related_count += 1
35
  else:
36
  not_job_related_count += 1
 
38
  return job_related_count, not_job_related_count
39
 
40
  # Streamlit application
41
+ st.title("Job Related Message Analyzer")
42
 
43
  uploaded_file = st.file_uploader("Upload CSV file")
44
  user_input = st.text_input("Enter text")
 
53
  result = analyze_text(message)
54
  results.append(result)
55
 
56
+ data["Job Related"] = results
57
 
58
  # Count job-related messages
59
  job_related_count, not_job_related_count = count_job_related_messages(data)
60
 
61
  st.dataframe(data)
62
+ st.write(f"Job Related Messages: {job_related_count}")
63
+ st.write(f"Not Job Related Messages: {not_job_related_count}")
64
  elif user_input:
65
  # Analyze user-input text
66
  result = analyze_text(user_input)