pyresearch committed
Commit 4cad844 · verified · 1 Parent(s): b0e9ffc

Update app.py

Files changed (1)
  1. app.py +93 -61
app.py CHANGED
@@ -5,9 +5,9 @@ from clarifai_grpc.grpc.api.status import status_code_pb2
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer

-
-model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype="auto", trust_remote_code=True)
-tokenizer = AutoTokenizer.from_pretraine
+torch.set_default_device("cpu")
+model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype="auto", trust_remote_code=True).to(device)
+tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)

 # GPT-4 credentials
 PAT_GPT4 = "3ca5bd8b0f2244eb8d0e4b2838fc3cf1"
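Note on this hunk: the new tokenizer line completes the previously truncated `from_pretraine` call, but the new model line ends in `.to(device)` and `device` does not appear to be defined anywhere in the context shown, so the load would raise a NameError at startup. A minimal corrected sketch; the explicit `device` variable is an assumption, not part of the commit:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Assumption: choose the device explicitly instead of referencing an undefined name.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.set_default_device(device)

    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)

Also worth noting: after this commit the phi-2 generation UI is removed in the next hunk, so the loaded model and tokenizer are no longer used anywhere in the file.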
@@ -37,86 +37,80 @@ APP_ID_NEWSGUARDIAN = "your_app_id"
 MODEL_ID_NEWSGUARDIAN = "your_model_id"
 MODEL_VERSION_ID_NEWSGUARDIAN = "your_model_version_id"

+#

 # Set up gRPC channel for NewsGuardian model
 channel_tts = ClarifaiChannel.get_grpc_channel()
 stub_tts = service_pb2_grpc.V2Stub(channel_tts)
 metadata_tts = (('authorization', 'Key ' + PAT_TTS),)
-userDataObject_tts = resources_pb2.UserAppIDSet(user_id=USER_ID_TTS, app_id=APP_ID_TTS)
+userDataObject_tts = resources_pb2.UserAppIDSet(user_id=USER_ID_TTS, app_id=APP_ID_TTS,)

 # Streamlit app
-st.title("NewsGuardian and phi-2 Text Generation")
-
-# Function to generate text using the "microsoft/phi-2" model
-def generate_phi2_text(input_text):
-    inputs = tokenizer_phi2(input_text, return_tensors="pt", return_attention_mask=False)
-    outputs = model_phi2.generate(**inputs, max_length=200)
-    generated_text = tokenizer_phi2.batch_decode(outputs)[0]
-    return generated_text
-
-# User input for phi-2 model
-raw_text_phi2 = st.text_area("Enter text for phi-2 model")
-
-# Button to generate result using "microsoft/phi-2" model
-if st.button("Generate text with phi-2 model"):
-    if raw_text_phi2:
-        generated_text_phi2 = generate_phi2_text(raw_text_phi2)
-        st.text("Generated text with phi-2 model")
-        st.text(generated_text_phi2)
-    else:
-        st.warning("Please enter text for phi-2 model")
-
-# User input for selecting the model
-model_type = st.selectbox("Select Model", ["NewsGuardian model", "DALL-E", "phi-2"])
-raw_text_news_guardian = st.text_area("This news is real or fake?")
-image_upload_news_guardian = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
-
-# Button to generate result for the selected model
-if st.button("Generate Result"):
+st.title("NewsGuardian")
+
+
+# Inserting logo
+st.image("https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTdA-MJ_SUCRgLs1prqudpMdaX4x-x10Zqlwp7cpzXWCMM9xjBAJYWdJsDlLoHBqNpj8qs&usqp=CAU")
+# Function to get gRPC channel for NewsGuardian model
+def get_tts_channel():
+    channel_tts = ClarifaiChannel.get_grpc_channel()
+    return channel_tts, channel_tts.metadata
+
+
+
+# User input
+model_type = st.selectbox("Select Model", ["NewsGuardian model","NewsGuardian model"])
+raw_text = st.text_area("This news is real or fake?")
+image_upload = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
+
+# Button to generate result
+if st.button("NewsGuardian News Result"):
     if model_type == "NewsGuardian model":
         # Set up gRPC channel for NewsGuardian model
-        channel_news_guardian = ClarifaiChannel.get_grpc_channel()
-        stub_news_guardian = service_pb2_grpc.V2Stub(channel_news_guardian)
-        metadata_news_guardian = (('authorization', 'Key ' + PAT_NEWSGUARDIAN),)
-        userDataObject_news_guardian = resources_pb2.UserAppIDSet(user_id=USER_ID_NEWSGUARDIAN, app_id=APP_ID_NEWSGUARDIAN)
+        channel_gpt4 = ClarifaiChannel.get_grpc_channel()
+        stub_gpt4 = service_pb2_grpc.V2Stub(channel_gpt4)
+        metadata_gpt4 = (('authorization', 'Key ' + PAT_GPT4),)
+        userDataObject_gpt4 = resources_pb2.UserAppIDSet(user_id=USER_ID_GPT4, app_id=APP_ID_GPT4)

         # Prepare the request for NewsGuardian model
-        input_data_news_guardian = resources_pb2.Data()
+        input_data_gpt4 = resources_pb2.Data()

-        if raw_text_news_guardian:
-            input_data_news_guardian.text.raw = raw_text_news_guardian
+        if raw_text:
+            input_data_gpt4.text.raw = raw_text

-        if image_upload_news_guardian is not None:
-            image_bytes_news_guardian = image_upload_news_guardian.read()
-            input_data_news_guardian.image.base64 = image_bytes_news_guardian
+        if image_upload is not None:
+            image_bytes_gpt4 = image_upload.read()
+            input_data_gpt4.image.base64 = image_bytes_gpt4

-        post_model_outputs_response_news_guardian = stub_news_guardian.PostModelOutputs(
+        post_model_outputs_response_gpt4 = stub_gpt4.PostModelOutputs(
             service_pb2.PostModelOutputsRequest(
-                user_app_id=userDataObject_news_guardian,
-                model_id=MODEL_ID_NEWSGUARDIAN,
-                version_id=MODEL_VERSION_ID_NEWSGUARDIAN,
-                inputs=[resources_pb2.Input(data=input_data_news_guardian)]
+                user_app_id=userDataObject_gpt4,
+                model_id=MODEL_ID_GPT4,
+                version_id=MODEL_VERSION_ID_GPT4,
+                inputs=[resources_pb2.Input(data=input_data_gpt4)]
             ),
-            metadata=metadata_news_guardian # Use metadata directly in the gRPC request
+            metadata=metadata_gpt4 # Use metadata directly in the gRPC request
         )

         # Check if the request was successful for NewsGuardian model
-        if post_model_outputs_response_news_guardian.status.code != status_code_pb2.SUCCESS:
-            st.error(f"NewsGuardian model API request failed: {post_model_outputs_response_news_guardian.status.description}")
+        if post_model_outputs_response_gpt4.status.code != status_code_pb2.SUCCESS:
+            st.error(f"NewsGuardian model API request failed: {post_model_outputs_response_gpt4.status.description}")
         else:
             # Get the output for NewsGuardian model
-            output_news_guardian = post_model_outputs_response_news_guardian.outputs[0].data
+            output_gpt4 = post_model_outputs_response_gpt4.outputs[0].data

             # Display the result for NewsGuardian model
-            if output_news_guardian.HasField("image"):
-                st.image(output_news_guardian.image.base64, caption='Generated Image (NewsGuardian model)', use_column_width=True)
-            elif output_news_guardian.HasField("text"):
+            if output_gpt4.HasField("image"):
+                st.image(output_gpt4.image.base64, caption='Generated Image (NewsGuardian model)', use_column_width=True)
+            elif output_gpt4.HasField("text"):
                 # Display the text result
-                st.text(output_news_guardian.text.raw)
+                st.text(output_gpt4.text.raw)

                 # Convert text to speech and play the audio
+                stub_tts = service_pb2_grpc.V2Stub(channel_gpt4) # Use the same channel for TTS
+
                 tts_input_data = resources_pb2.Data()
-                tts_input_data.text.raw = output_news_guardian.text.raw
+                tts_input_data.text.raw = output_gpt4.text.raw

                 tts_response = stub_tts.PostModelOutputs(
                     service_pb2.PostModelOutputsRequest(
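The hunk above mostly renames the `*_news_guardian` variables to `*_gpt4` while keeping the same Clarifai call shape. For reference, the request pattern this file repeats for every model (GPT-4, DALL-E, TTS) reduces to the sketch below; the helper name `call_clarifai_model` is hypothetical and not part of the commit, but every API call in it already appears in the diff:

    from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
    from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
    from clarifai_grpc.grpc.api.status import status_code_pb2

    def call_clarifai_model(pat, user_id, app_id, model_id, version_id, text=None, image_bytes=None):
        # One gRPC channel and stub per call, authorized with the personal access token.
        channel = ClarifaiChannel.get_grpc_channel()
        stub = service_pb2_grpc.V2Stub(channel)
        metadata = (('authorization', 'Key ' + pat),)
        user_app_id = resources_pb2.UserAppIDSet(user_id=user_id, app_id=app_id)

        # Build the input payload from whatever was provided.
        data = resources_pb2.Data()
        if text:
            data.text.raw = text
        if image_bytes:
            data.image.base64 = image_bytes

        response = stub.PostModelOutputs(
            service_pb2.PostModelOutputsRequest(
                user_app_id=user_app_id,
                model_id=model_id,
                version_id=version_id,
                inputs=[resources_pb2.Input(data=data)],
            ),
            metadata=metadata,
        )
        if response.status.code != status_code_pb2.SUCCESS:
            raise RuntimeError(f"Clarifai request failed: {response.status.description}")
        return response.outputs[0].data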
@@ -125,7 +119,7 @@ if st.button("Generate Result"):
                         version_id=MODEL_VERSION_ID_TTS,
                         inputs=[resources_pb2.Input(data=tts_input_data)]
                     ),
-                    metadata=metadata_tts # Use the same metadata for TTS
+                    metadata=metadata_gpt4 # Use the same metadata for TTS
                 )

                 # Check if the TTS request was successful
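One behavioral consequence of this hunk: the text-to-speech request is now authorized with `metadata_gpt4` (the GPT-4 key) even though it still targets `MODEL_VERSION_ID_TTS`, which only works if `PAT_GPT4` also has access to the TTS app. A sketch of the alternative, reusing names already defined in app.py; the `user_app_id` and `model_id` fields are not shown in this hunk and are assumed here:

    tts_response = stub_tts.PostModelOutputs(
        service_pb2.PostModelOutputsRequest(
            user_app_id=userDataObject_tts,
            model_id=MODEL_ID_TTS,
            version_id=MODEL_VERSION_ID_TTS,
            inputs=[resources_pb2.Input(data=tts_input_data)],
        ),
        metadata=metadata_tts,  # TTS credentials instead of the GPT-4 key
    )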
@@ -133,7 +127,7 @@ if st.button("Generate Result"):
                     tts_output = tts_response.outputs[0].data
                     st.audio(tts_output.audio.base64, format='audio/wav')
                 else:
-                    st.error(f"TTS API request failed: {tts_response.status.description}")
+                    st.error(f"NewsGuardian model API request failed: {tts_response.status.description}")

     elif model_type == "DALL-E":
         # Set up gRPC channel for DALL-E
@@ -145,8 +139,8 @@ if st.button("Generate Result"):
         # Prepare the request for DALL-E
         input_data_dalle = resources_pb2.Data()

-        if raw_text_news_guardian:
-            input_data_dalle.text.raw = raw_text_news_guardian
+        if raw_text:
+            input_data_dalle.text.raw = raw_text

         post_model_outputs_response_dalle = stub_dalle.PostModelOutputs(
             service_pb2.PostModelOutputsRequest(
@@ -171,6 +165,44 @@ if st.button("Generate Result"):
             elif output_dalle.HasField("text"):
                 st.text(output_dalle.text.raw)

+    elif model_type == "NewsGuardian model":
+        # Set up gRPC channel for NewsGuardian model
+        channel_tts = ClarifaiChannel.get_grpc_channel()
+        stub_tts = service_pb2_grpc.V2Stub(channel_tts)
+        metadata_tts = (('authorization', 'Key ' + PAT_TTS),)
+        userDataObject_tts = resources_pb2.UserAppIDSet(user_id=USER_ID_TTS, app_id=APP_ID_TTS)
+
+        # Prepare the request for NewsGuardian model
+        input_data_tts = resources_pb2.Data()
+
+        if raw_text:
+            input_data_tts.text.raw = raw_text
+
+        post_model_outputs_response_tts = stub_tts.PostModelOutputs(
+            service_pb2.PostModelOutputsRequest(
+                user_app_id=userDataObject_tts,
+                model_id=MODEL_ID_TTS,
+                version_id=MODEL_VERSION_ID_TTS,
+                inputs=[resources_pb2.Input(data=input_data_tts)]
+            ),
+            metadata=metadata_tts
+        )
+
+        # Check if the request was successful for NewsGuardian model
+        if post_model_outputs_response_tts.status.code != status_code_pb2.SUCCESS:
+            st.error(f"NewsGuardian model API request failed: {post_model_outputs_response_tts.status.description}")
+        else:
+            # Get the output for NewsGuardian model
+            output_tts = post_model_outputs_response_tts.outputs[0].data
+
+            # Display the result for NewsGuardian model
+            if output_tts.HasField("text"):
+                st.text(output_tts.text.raw)
+
+            if output_tts.HasField("audio"):
+                st.audio(output_tts.audio.base64, format='audio/wav')
+
+
 # Add the beautiful social media icon section
 st.markdown("""
 <div align="center">
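A note on the `elif model_type == "NewsGuardian model":` branch added in this hunk: the selectbox added earlier in the commit offers only `["NewsGuardian model","NewsGuardian model"]`, and the first `if` already matches that value, so both this new TTS branch and the existing DALL-E branch are unreachable. A small sketch of a selectbox that keeps all three branches selectable; the third label is an assumption, since the commit does not name one, and the new `elif` would need to compare against it:

    import streamlit as st

    # Assumed labels: the first two match existing branches, the third is hypothetical.
    model_type = st.selectbox("Select Model", ["NewsGuardian model", "DALL-E", "NewsGuardian TTS"])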
@@ -193,4 +225,4 @@ st.markdown("""
 <img src="https://user-images.githubusercontent.com/34125851/226601355-ffe0b597-9840-4e10-bbef-43d6c74b5a9e.png" width="2%" alt="" /></a>
 </div>
 <hr>
-""", unsafe_allow_html=True)
+""", unsafe_allow_html=True)