pyresearch commited on
Commit
e67784d
·
verified ·
1 Parent(s): de2fcea

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +188 -16
app.py CHANGED
@@ -1,26 +1,198 @@
1
- import torch
2
  import streamlit as st
 
 
 
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
  torch.set_default_device("cpu")
6
-
7
- model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype="auto", trust_remote_code=True)
8
  tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
9
 
10
- st.title("Text Generation with phi-2 Model")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # Input text area for the user
13
- user_input = st.text_area("Enter some text:", '''def print_prime(n):
14
- """
15
- Print all primes between 1 and n
16
- """''')
17
 
18
- # Generate text based on user input
19
- if st.button("Generate Text"):
20
- inputs = tokenizer(user_input, return_tensors="pt", return_attention_mask=False)
 
 
 
21
  outputs = model.generate(**inputs, max_length=200)
22
  generated_text = tokenizer.batch_decode(outputs)[0]
23
-
24
- # Display the generated text
25
- st.subheader("Generated Text:")
26
- st.write(generated_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
3
+ from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
4
+ from clarifai_grpc.grpc.api.status import status_code_pb2
5
+ import torch
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
 
8
  torch.set_default_device("cpu")
9
+ model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype="auto", trust_remote_code=True).to(device)
 
10
  tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
11
 
12
+ # GPT-4 credentials
13
+ PAT_GPT4 = "3ca5bd8b0f2244eb8d0e4b2838fc3cf1"
14
+ USER_ID_GPT4 = "openai"
15
+ APP_ID_GPT4 = "chat-completion"
16
+ MODEL_ID_GPT4 = "openai-gpt-4-vision"
17
+ MODEL_VERSION_ID_GPT4 = "266df29bc09843e0aee9b7bf723c03c2"
18
+
19
+ # DALL-E credentials
20
+ PAT_DALLE = "bfdeb4029ef54d23a2e608b0aa4c00e4"
21
+ USER_ID_DALLE = "openai"
22
+ APP_ID_DALLE = "dall-e"
23
+ MODEL_ID_DALLE = "dall-e-3"
24
+ MODEL_VERSION_ID_DALLE = "dc9dcb6ee67543cebc0b9a025861b868"
25
+
26
+ # TTS credentials
27
+ PAT_TTS = "bfdeb4029ef54d23a2e608b0aa4c00e4"
28
+ USER_ID_TTS = "openai"
29
+ APP_ID_TTS = "tts"
30
+ MODEL_ID_TTS = "openai-tts-1"
31
+ MODEL_VERSION_ID_TTS = "fff6ce1fd487457da95b79241ac6f02d"
32
+
33
+ # NewsGuardian model credentials
34
+ PAT_NEWSGUARDIAN = "your_news_guardian_pat"
35
+ USER_ID_NEWSGUARDIAN = "your_user_id"
36
+ APP_ID_NEWSGUARDIAN = "your_app_id"
37
+ MODEL_ID_NEWSGUARDIAN = "your_model_id"
38
+ MODEL_VERSION_ID_NEWSGUARDIAN = "your_model_version_id"
39
+
40
+ # Set up gRPC channel for NewsGuardian model
41
+ channel_tts = ClarifaiChannel.get_grpc_channel()
42
+ stub_tts = service_pb2_grpc.V2Stub(channel_tts)
43
+ metadata_tts = (('authorization', 'Key ' + PAT_TTS),)
44
+ userDataObject_tts = resources_pb2.UserAppIDSet(user_id=USER_ID_TTS, app_id=APP_ID_TTS)
45
 
46
+ # Streamlit app
47
+ st.title("NewsGuardian")
 
 
 
48
 
49
+ # Inserting logo
50
+ st.image("https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTdA-MJ_SUCRgLs1prqudpMdaX4x-x10Zqlwp7cpzXWCMM9xjBAJYWdJsDlLoHBqNpj8qs&usqp=CAU")
51
+
52
+ # Function to generate text using the "microsoft/phi-2" model
53
+ def generate_phi2_text(input_text):
54
+ inputs = tokenizer(input_text, return_tensors="pt", return_attention_mask=False)
55
  outputs = model.generate(**inputs, max_length=200)
56
  generated_text = tokenizer.batch_decode(outputs)[0]
57
+ return generated_text
58
+
59
+ # User input
60
+ raw_text_phi2 = st.text_area("Enter text for phi-2 model")
61
+
62
+ # Button to generate result using "microsoft/phi-2" model
63
+ if st.button("NewsGuardian model Generated fake news with phi-2"):
64
+ if raw_text_phi2:
65
+ generated_text_phi2 = generate_phi2_text(raw_text_phi2)
66
+ st.text("NewsGuardian model Generated fake news with phi-2")
67
+ st.text(generated_text_phi2)
68
+ else:
69
+ st.warning("Please enter news phi-2 model")
70
+
71
+ # User input
72
+ model_type = st.selectbox("Select Model", ["NewsGuardian model", "DALL-E"])
73
+ raw_text_news_guardian = st.text_area("This news is real or fake?")
74
+ image_upload_news_guardian = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
75
+
76
+ # Button to generate result for NewsGuardian model
77
+ if st.button("NewsGuardian News Result"):
78
+ if model_type == "NewsGuardian model":
79
+ # Set up gRPC channel for NewsGuardian model
80
+ channel_news_guardian = ClarifaiChannel.get_grpc_channel()
81
+ stub_news_guardian = service_pb2_grpc.V2Stub(channel_news_guardian)
82
+ metadata_news_guardian = (('authorization', 'Key ' + PAT_NEWSGUARDIAN),)
83
+ userDataObject_news_guardian = resources_pb2.UserAppIDSet(user_id=USER_ID_NEWSGUARDIAN, app_id=APP_ID_NEWSGUARDIAN)
84
+
85
+ # Prepare the request for NewsGuardian model
86
+ input_data_news_guardian = resources_pb2.Data()
87
+
88
+ if raw_text_news_guardian:
89
+ input_data_news_guardian.text.raw = raw_text_news_guardian
90
+
91
+ if image_upload_news_guardian is not None:
92
+ image_bytes_news_guardian = image_upload_news_guardian.read()
93
+ input_data_news_guardian.image.base64 = image_bytes_news_guardian
94
+
95
+ post_model_outputs_response_news_guardian = stub_news_guardian.PostModelOutputs(
96
+ service_pb2.PostModelOutputsRequest(
97
+ user_app_id=userDataObject_news_guardian,
98
+ model_id=MODEL_ID_NEWSGUARDIAN,
99
+ version_id=MODEL_VERSION_ID_NEWSGUARDIAN,
100
+ inputs=[resources_pb2.Input(data=input_data_news_guardian)]
101
+ ),
102
+ metadata=metadata_news_guardian # Use metadata directly in the gRPC request
103
+ )
104
+
105
+ # Check if the request was successful for NewsGuardian model
106
+ if post_model_outputs_response_news_guardian.status.code != status_code_pb2.SUCCESS:
107
+ st.error(f"NewsGuardian model API request failed: {post_model_outputs_response_news_guardian.status.description}")
108
+ else:
109
+ # Get the output for NewsGuardian model
110
+ output_news_guardian = post_model_outputs_response_news_guardian.outputs[0].data
111
+
112
+ # Display the result for NewsGuardian model
113
+ if output_news_guardian.HasField("image"):
114
+ st.image(output_news_guardian.image.base64, caption='Generated Image (NewsGuardian model)', use_column_width=True)
115
+ elif output_news_guardian.HasField("text"):
116
+ # Display the text result
117
+ st.text(output_news_guardian.text.raw)
118
+
119
+ # Convert text to speech and play the audio
120
+ tts_input_data = resources_pb2.Data()
121
+ tts_input_data.text.raw = output_news_guardian.text.raw
122
+
123
+ tts_response = stub_tts.PostModelOutputs(
124
+ service_pb2.PostModelOutputsRequest(
125
+ user_app_id=userDataObject_tts,
126
+ model_id=MODEL_ID_TTS,
127
+ version_id=MODEL_VERSION_ID_TTS,
128
+ inputs=[resources_pb2.Input(data=tts_input_data)]
129
+ ),
130
+ metadata=metadata_tts # Use the same metadata for TTS
131
+ )
132
+
133
+ # Check if the TTS request was successful
134
+ if tts_response.status.code == status_code_pb2.SUCCESS:
135
+ tts_output = tts_response.outputs[0].data
136
+ st.audio(tts_output.audio.base64, format='audio/wav')
137
+ else:
138
+ st.error(f"TTS API request failed: {tts_response.status.description}")
139
+
140
+ elif model_type == "DALL-E":
141
+ # Set up gRPC channel for DALL-E
142
+ channel_dalle = ClarifaiChannel.get_grpc_channel()
143
+ stub_dalle = service_pb2_grpc.V2Stub(channel_dalle)
144
+ metadata_dalle = (('authorization', 'Key ' + PAT_DALLE),)
145
+ userDataObject_dalle = resources_pb2.UserAppIDSet(user_id=USER_ID_DALLE, app_id=APP_ID_DALLE)
146
+
147
+ # Prepare the request for DALL-E
148
+ input_data_dalle = resources_pb2.Data()
149
+
150
+ if raw_text_news_guardian:
151
+ input_data_dalle.text.raw = raw_text_news_guardian
152
+
153
+ post_model_outputs_response_dalle = stub_dalle.PostModelOutputs(
154
+ service_pb2.PostModelOutputsRequest(
155
+ user_app_id=userDataObject_dalle,
156
+ model_id=MODEL_ID_DALLE,
157
+ version_id=MODEL_VERSION_ID_DALLE,
158
+ inputs=[resources_pb2.Input(data=input_data_dalle)]
159
+ ),
160
+ metadata=metadata_dalle
161
+ )
162
+
163
+ # Check if the request was successful for DALL-E
164
+ if post_model_outputs_response_dalle.status.code != status_code_pb2.SUCCESS:
165
+ st.error(f"DALL-E API request failed: {post_model_outputs_response_dalle.status.description}")
166
+ else:
167
+ # Get the output for DALL-E
168
+ output_dalle = post_model_outputs_response_dalle.outputs[0].data
169
+
170
+ # Display the result for DALL-E
171
+ if output_dalle.HasField("image"):
172
+ st.image(output_dalle.image.base64, caption='Generated Image (DALL-E)', use_column_width=True)
173
+ elif output_dalle.HasField("text"):
174
+ st.text(output_dalle.text.raw)
175
+
176
+ # Add the beautiful social media icon section
177
+ st.markdown("""
178
+ <div align="center">
179
+ <a href="https://github.com/pyresearch/pyresearch" style="text-decoration:none;">
180
+ <img src="https://user-images.githubusercontent.com/34125851/226594737-c21e2dda-9cc6-42ef-b4e7-a685fea4a21d.png" width="2%" alt="" /></a>
181
+ <img src="https://user-images.githubusercontent.com/34125851/226595799-160b0da3-c9e0-4562-8544-5f20460f7cc9.png" width="2%" alt="" />
182
+ <a href="https://www.linkedin.com/company/pyresearch/" style="text-decoration:none;">
183
+ <img src="https://user-images.githubusercontent.com/34125851/226596446-746ffdd0-a47e-4452-84e3-bf11ec2aa26a.png" width="2%" alt="" /></a>
184
+ <img src="https://user-images.githubusercontent.com/34125851/226595799-160b0da3-c9e0-4562-8544-5f20460f7cc9.png" width="2%" alt="" />
185
+ <a href="https://twitter.com/Noorkhokhar10" style="text-decoration:none;">
186
+ <img src="https://user-images.githubusercontent.com/34125851/226599162-9b11194e-4998-440a-ba94-c8a5e1cdc676.png" width="2%" alt="" /></a>
187
+ <img src="https://user-images.githubusercontent.com/34125851/226595799-160b0da3-c9e0-4562-8544-5f20460f7cc9.png" width="2%" alt="" />
188
+ <a href="https://www.youtube.com/@Pyresearch" style="text-decoration:none;">
189
+ <img src="https://user-images.githubusercontent.com/34125851/226599904-7d5cc5c0-89d2-4d1e-891e-19bee1951744.png" width="2%" alt="" /></a>
190
+ <img src="https://user-images.githubusercontent.com/34125851/226595799-160b0da3-c9e0-4562-8544-5f20460f7cc9.png" width="2%" alt="" />
191
+ <a href="https://www.facebook.com/Pyresearch" style="text-decoration:none;">
192
+ <img src="https://user-images.githubusercontent.com/34125851/226600380-a87a9142-e8e0-4ec9-bf2c-dd6e9da2f05a.png" width="2%" alt="" /></a>
193
+ <img src="https://user-images.githubusercontent.com/34125851/226595799-160b0da3-c9e0-4562-8544-5f20460f7cc9.png" width="2%" alt="" />
194
+ <a href="https://www.instagram.com/pyresearch/" style="text-decoration:none;">
195
+ <img src="https://user-images.githubusercontent.com/34125851/226601355-ffe0b597-9840-4e10-bbef-43d6c74b5a9e.png" width="2%" alt="" /></a>
196
+ </div>
197
+ <hr>
198
+ """, unsafe_allow_html=True)