akhil2808 commited on
Commit
5c74af6
·
1 Parent(s): 3aa1e32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -6
app.py CHANGED
@@ -2,21 +2,30 @@ import gradio as gr
2
  import torch
3
  from transformers import BertTokenizerFast, BertForSequenceClassification
4
 
5
- def predict_news_category(text, model, tokenizer, device):
6
- inputs = tokenizer(text, truncation=True, padding=True, max_length=512, return_tensors='pt').to(device)
7
- model.to(device)
 
 
8
  outputs = model(**inputs)
9
  probs = outputs[0].softmax(1)
 
 
10
  _, predicted_category = torch.max(probs, dim=1)
 
11
  return predicted_category.item()
12
 
 
13
  model = BertForSequenceClassification.from_pretrained('akhil2808/EPICS-PROJECT')
14
- model.eval()
 
 
15
  tokenizer = BertTokenizerFast.from_pretrained('akhil2808/EPICS-PROJECT')
16
 
 
17
  def detect_news_category(text):
18
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
- category = predict_news_category(text, model, tokenizer, device)
20
  prediction_dict = {1: 'Real News', 0: 'Fake News'}
21
  return prediction_dict[category]
22
 
@@ -28,3 +37,4 @@ iface = gr.Interface(fn=detect_news_category,
28
  theme='huggingface')
29
 
30
  iface.launch()
 
 
2
  import torch
3
  from transformers import BertTokenizerFast, BertForSequenceClassification
4
 
5
+ def predict_news_category(text, model, tokenizer):
6
+ # Tokenize input text
7
+ inputs = tokenizer(text, truncation=True, padding=True, max_length=512, return_tensors='pt')
8
+
9
+ # Predict
10
  outputs = model(**inputs)
11
  probs = outputs[0].softmax(1)
12
+
13
+ # Get the predicted category
14
  _, predicted_category = torch.max(probs, dim=1)
15
+
16
  return predicted_category.item()
17
 
18
+ # Load your model
19
  model = BertForSequenceClassification.from_pretrained('akhil2808/EPICS-PROJECT')
20
+ model.eval() # Set the model to evaluation mode
21
+
22
+ # Load tokenizer
23
  tokenizer = BertTokenizerFast.from_pretrained('akhil2808/EPICS-PROJECT')
24
 
25
+ # Function to predict news category
26
  def detect_news_category(text):
27
+ category = predict_news_category(text, model, tokenizer)
28
+ # Map the prediction to fake or real news
29
  prediction_dict = {1: 'Real News', 0: 'Fake News'}
30
  return prediction_dict[category]
31
 
 
37
  theme='huggingface')
38
 
39
  iface.launch()
40
+