Naphat Sornwichai committed
Commit b4c6511 · 1 Parent(s): 995e28f

update major files

Files changed (1)
  1. app.py +35 -55
app.py CHANGED
@@ -1,12 +1,11 @@
 import gradio as gr
 import torch
-from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
+from transformers import pipeline
 import yt_dlp
 from openai import OpenAI
 import os
 import json
 import torchaudio
-import torchaudio.transforms as T
 import time
 
 # --- 1. Model & Pipeline Initialization ---
@@ -21,14 +20,16 @@ model_id = "nectec/Pathumma-whisper-th-medium"
 
 print(f"Using device: {device} with dtype: {torch_dtype}")
 
-# Load the model and processor directly
-# We will use the model's .generate() method for long-form transcription
-model = AutoModelForSpeechSeq2Seq.from_pretrained(
-    model_id, dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
+# Initialize the ASR pipeline, which is more robust for handling inputs
+pipe = pipeline(
+    task="automatic-speech-recognition",
+    model=model_id,
+    dtype=torch_dtype,
+    device=device,
 )
-model.to(device)
 
-processor = AutoProcessor.from_pretrained(model_id)
+# Set the language and task for the pipeline
+pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language="th", task="transcribe")
 
 
 print("Transcription model loaded successfully.")
@@ -59,10 +60,11 @@ def download_youtube_audio(url: str) -> str:
 
 
 # --- 3. Core Logic ---
-def transcribe_and_summarize(audio_file: str, youtube_url: str, progress=gr.Progress()):
+def transcribe_and_summarize(audio_file: str, youtube_url: str):
     """
     Main function to process audio, transcribe, and summarize.
     This is a generator function to yield status updates and logs to the UI.
+    No longer uses gr.Progress, shows loading state in the output component itself.
     """
     log_history = ""
     def log(message):
@@ -71,8 +73,8 @@ def transcribe_and_summarize(audio_file: str, youtube_url: str, progress=gr.Prog
         log_history += f"[{timestamp}] {message}\n"
         return log_history
 
-    progress(0, desc="Starting...")
-    yield log("Process started."), "", "", "Starting..."
+    loading_message = "⏳ Please wait, your article is being generated..."
+    yield log("Process started."), "", loading_message
 
     # Step 1: Get API Key and validate inputs
    api_key = os.getenv('TYPHOON_API')
@@ -84,55 +86,34 @@ def transcribe_and_summarize(audio_file: str, youtube_url: str, progress=gr.Prog
     # Step 2: Determine audio source and get file path
     filepath = ""
     if youtube_url:
-        progress(0.1, desc="Downloading Audio...")
-        yield log("YouTube link detected. Starting download."), "", "", "Downloading Audio..."
+        yield log("YouTube link detected. Starting download."), "", loading_message
         try:
             filepath = download_youtube_audio(youtube_url)
-            yield log(f"Audio downloaded successfully to '{filepath}'."), "", "", "Download Complete"
+            yield log(f"Audio downloaded successfully to '{filepath}'."), "", loading_message
         except Exception as e:
-            yield log(f"Error downloading from YouTube: {e}"), "", "", f"Error: {e}"
+            yield log(f"Error downloading from YouTube: {e}"), "", ""
             return
     else:
         filepath = audio_file
-        yield log(f"Processing uploaded file: '{filepath}'."), "", "", "Processing File..."
+        yield log(f"Processing uploaded file: '{filepath}'."), "", loading_message
 
 
-    # Step 3: Transcribe audio using the model's generate method for long-form audio
-    progress(0.3, desc="Transcribing Audio...")
-    yield log("Beginning audio transcription..."), "", "", "Transcribing Audio..."
+    # Step 3: Transcribe audio using the pipeline for robustness
+    yield log("Beginning audio transcription... This may take a while for long audio."), "", loading_message
     try:
-        # Load audio file using torchaudio
-        waveform, sr = torchaudio.load(filepath)
-
-        # Resample to 16kHz if necessary, as Whisper expects this rate
-        if sr != 16000:
-            yield log(f"Original sample rate is {sr}Hz. Resampling to 16000Hz."), "", "", "Resampling..."
-            resampler = T.Resample(orig_freq=sr, new_freq=16000)
-            waveform = resampler(waveform)
-
-        # Process the audio waveform to get input features
-        input_features = processor(
-            waveform.squeeze().numpy(),
-            return_tensors="pt",
-            sampling_rate=16000
-        ).input_features.to(device, dtype=torch_dtype)
-
-        # Generate token IDs from the input features, passing task and language directly
-        predicted_ids = model.generate(input_features, language="th", task="transcribe")
-
-        # Decode the token IDs to text
-        transcribed_text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
-        yield log("Transcription complete."), transcribed_text, "", "Transcription Complete"
+        # The pipeline handles resampling, chunking, and batching automatically
+        result = pipe(filepath, chunk_length_s=30, batch_size=8, return_timestamps=False)
+        transcribed_text = result["text"]
+        yield log("Transcription complete."), transcribed_text, loading_message
 
     except Exception as e:
         raise gr.Error(f"An error occurred during transcription: {str(e)}")
 
 
     # Step 4: Summarize with Typhoon LLM
-    progress(0.8, desc="Generating Summary...")
-    yield log("Sending transcription to Typhoon LLM for summarization."), transcribed_text, "", "Generating Summary..."
+    yield log("Sending transcription to Typhoon LLM for summarization."), transcribed_text, loading_message
     if not transcribed_text or not transcribed_text.strip():
-        yield log("Transcription is empty. Aborting summarization."), "", "Could not generate summary because the transcription is empty.", "Aborted"
+        yield log("Transcription is empty. Aborting summarization."), "", "Could not generate summary because the transcription is empty."
         return
 
     # Initialize OpenAI client for Typhoon
@@ -166,7 +147,7 @@ The JSON object must have the following structure:
             temperature=0.7
         )
         summary_json_string = response.choices[0].message.content
-        yield log("Received summary from Typhoon LLM. Parsing JSON."), transcribed_text, "", "Parsing Summary..."
+        yield log("Received summary from Typhoon LLM. Parsing JSON."), transcribed_text, loading_message
 
         # Parse the JSON and format it as Markdown
         try:
@@ -182,14 +163,15 @@ The JSON object must have the following structure:
 
             # Build the blog post in Markdown format
             summary_markdown = f"# {title}\n\n"
-            summary_markdown += f"{key_takeaway}\n\n"
+            summary_markdown += f"<p>{key_takeaway}</p>\n\n"
             if main_ideas:
                 summary_markdown += "## Key Features & Main Ideas\n\n"
+                summary_markdown += "<ul>\n"
                 for idea in main_ideas:
-                    summary_markdown += f"- {idea}\n"
-                summary_markdown += "\n"
-            summary_markdown += f"## Conclusion\n\n{conclusion}"
-            yield log("Successfully parsed and formatted summary."), transcribed_text, summary_markdown, "Formatting Complete"
+                    summary_markdown += f"  <li>{idea}</li>\n"
+                summary_markdown += "</ul>\n\n"
+            summary_markdown += f"## Conclusion\n\n<p>{conclusion}</p>"
+            yield log("Successfully parsed and formatted summary."), transcribed_text, summary_markdown
 
         except (json.JSONDecodeError, AttributeError) as e:
             error_message = f"Failed to parse the summary from the AI. Raw response: {summary_json_string}"
@@ -199,8 +181,7 @@ The JSON object must have the following structure:
         raise gr.Error(f"Could not connect to the Typhoon API. Please check your API key. Error: {str(e)}")
 
     # Step 5: Return final results
-    progress(1.0, desc="Done!")
-    yield log("Process finished successfully."), transcribed_text, summary_markdown, "Done!"
+    yield log("Process finished successfully."), transcribed_text, summary_markdown
 
 # --- 4. Gradio UI ---
 # Custom CSS for a beautiful, blog-like output.
@@ -270,8 +251,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), cs
             )
 
             submit_button = gr.Button("🚀 Generate Blog Post", variant="primary")
-            status_output = gr.Textbox(label="Status", interactive=False, lines=1)
-            with gr.Accordion("📝 View Process Log", open=False):
+            with gr.Accordion("📝 View Process Log", open=True):
                 log_output = gr.Textbox(label="Log", interactive=False, lines=10)
 
         with gr.Column(scale=2):
@@ -285,7 +265,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), cs
     submit_button.click(
         fn=transcribe_and_summarize,
        inputs=[audio_file_input, youtube_url_input],
-        outputs=[log_output, transcription_output, blog_summary_output, status_output]
+        outputs=[log_output, transcription_output, blog_summary_output]
     )
 
 if __name__ == "__main__":
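
For reference, a minimal standalone sketch of the chunked transcription path this commit switches to. It mirrors the model ID and pipeline settings from app.py above; "sample.wav" is a hypothetical local file, and the snippet assumes transformers and torch are installed.

import torch
from transformers import pipeline

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Same model and configuration as app.py after this commit
pipe = pipeline(
    task="automatic-speech-recognition",
    model="nectec/Pathumma-whisper-th-medium",
    dtype=torch_dtype,
    device=device,
)
pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(
    language="th", task="transcribe"
)

# chunk_length_s splits long audio into 30-second windows and batch_size runs
# several windows per forward pass; resampling, feature extraction, and
# decoding all happen inside the pipeline ("sample.wav" is hypothetical)
result = pipe("sample.wav", chunk_length_s=30, batch_size=8, return_timestamps=False)
print(result["text"])

This is why the commit can drop torchaudio.transforms and the manual load/resample/generate/decode sequence: the pipeline performs those steps internally for long-form audio.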