File size: 6,927 Bytes
563eae5
e446af9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
acae57e
e446af9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f528982
e446af9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# /agent.py

import torch
import gc
import os
from transformers import AutoProcessor, AutoModelForImageTextToText
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from typing import Optional, Sequence
from typing_extensions import TypedDict

# Helper function to mimic LangGraph's add_messages
def add_messages(left: "Sequence[BaseMessage]", right: "Sequence[BaseMessage]") -> "Sequence[BaseMessage]":
    """Concatenate two message sequences into a new list.

    Mimics LangGraph's ``add_messages`` reducer. Coercing both operands to
    ``list`` first means mixed sequence types (e.g. a tuple history plus a
    list of new messages) no longer raise ``TypeError`` the way a bare
    ``left + right`` would.

    Args:
        left: Existing message history.
        right: New messages to append.

    Returns:
        A new list with the messages of ``left`` followed by ``right``;
        neither input is mutated.
    """
    return list(left) + list(right)

class AgentState(TypedDict):
    """Defines the state of our agent.

    A shared dict passed between pipeline steps: each step reads the keys
    it needs and returns a partial update with the keys it produced.
    """
    audio_path: Optional[str]  # filesystem path to the input audio, consumed by transcribe_audio
    image_path: Optional[str]  # filesystem path to the input image, consumed by describe_image
    transcribed_text: Optional[str]  # produced by transcribe_audio
    image_description: Optional[str]  # produced by describe_image
    news_report: Sequence[BaseMessage]  # AI drafts (AIMessage) interleaved with human feedback (HumanMessage)
    final_message: Optional[str]  # status string produced by save_report


class NewsReporterAgent:
    """Multimodal news-reporter pipeline around a Gemma-3n chat model.

    Each public method is a pipeline step that takes the shared
    ``AgentState`` dict and returns a partial state update, so the class
    can be wired into a LangGraph-style graph node by node.
    """

    def __init__(self):
        """Initializes the agent by loading the model and processor from the Hub."""
        print("--- πŸš€ INITIALIZING MODEL (this may take a moment) ---")

        # Hugging Face Hub model ID. The larger google/gemma-3n-E4B-it
        # variant also works but inference takes too long.
        model_id = "google/gemma-3n-E2B-it"

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(f"  > Using device: {self.device}")

        # Load directly from the Hugging Face Hub.
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_id, torch_dtype="auto"
        ).to(self.device)

        print("--- βœ… MODEL READY ---")

    def _generate(self, messages: list) -> str:
        """Run one chat-completion round and return only the new text.

        Args:
            messages: Chat messages in the processor's chat-template format
                (a list of ``{"role": ..., "content": [...]}`` dicts).

        Returns:
            The model's reply with special tokens removed and surrounding
            whitespace stripped.
        """
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(self.device, dtype=self.model.dtype)

        # 128 new tokens keeps latency manageable (64 would be faster still);
        # disable_compile avoids torch.compile warm-up on one-shot calls.
        outputs = self.model.generate(**inputs, max_new_tokens=128, disable_compile=True)
        # Slice off the prompt tokens so only the continuation is decoded.
        # skip_special_tokens stops markers such as <end_of_turn> from
        # leaking into downstream prompts and saved reports.
        text = self.processor.decode(
            outputs[0][inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True,
        ).strip()

        # Release the generation tensors eagerly to keep memory headroom
        # between pipeline steps.
        del inputs
        del outputs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        return text

    def transcribe_audio(self, state: AgentState) -> dict:
        """Transcribes the audio file specified in the state.

        Returns a ``{"transcribed_text": ...}`` update, or an empty dict
        when no audio path is present.
        """
        print("--- 🎀 TRANSCRIBING AUDIO ---")
        audio_path = state.get('audio_path')
        if not audio_path:
            return {}
        messages = [{"role": "user", "content": [{"type": "audio", "audio": audio_path}, {"type": "text", "text": "Transcribe the following audio. Provide only the transcribed text."}]}]
        transcribed_text = self._generate(messages)
        print("  > Transcription generated.")
        return {"transcribed_text": transcribed_text}

    def describe_image(self, state: AgentState) -> dict:
        """Generates a description for the image specified in the state.

        Returns an ``{"image_description": ...}`` update, or an empty dict
        when no image path is present.
        """
        print("--- πŸ–ΌοΈ DESCRIBING IMAGE ---")
        image_path = state.get('image_path')
        if not image_path:
            return {}
        messages = [{"role": "user", "content": [{"type": "image", "image": image_path}, {"type": "text", "text": "Describe this image in detail."}]}]
        image_description = self._generate(messages)
        print("  > Description generated.")
        return {"image_description": image_description}

    # This is the agent
    def create_report(self, state: AgentState) -> dict:
        """Generates a news report from transcription and/or image description.

        Returns a ``{"news_report": [AIMessage]}`` update; when neither
        source is available the AIMessage explains that instead of calling
        the model.
        """
        print("--- ✍️ GENERATING NEWS REPORT ---")
        context_parts = ["You are an expert news reporter. Your task is to write a clear, concise, and factual news report...", "Synthesize all available information into a single, coherent story..."]
        transcribed_text = state.get('transcribed_text')
        image_description = state.get('image_description')
        if not transcribed_text and not image_description:
            return {"news_report": [AIMessage(content="No input provided to generate a report.")]}
        if transcribed_text:
            context_parts.append(f"--- Transcribed Audio ---\n\"{transcribed_text}\"")
        if image_description:
            context_parts.append(f"--- Image Description ---\n\"{image_description}\"")
        prompt = "\n\n".join(context_parts)
        report_content = self._generate([{"role": "user", "content": [{"type": "text", "text": prompt}]}])
        print("  > Report generated successfully.")
        return {"news_report": [AIMessage(content=report_content)]}

    def revise_report(self, state: AgentState) -> dict:
        """Revises the news report based on the latest human feedback.

        Appends a new AIMessage revision to ``news_report`` (history is
        preserved, not replaced).
        """
        print("--- πŸ”„ REVISING REPORT ---")
        # These keys are declared Optional, so they may be present with a
        # None value; `or` covers that case where dict.get's default (which
        # only applies to *missing* keys) would interpolate "None".
        transcribed = state.get("transcribed_text") or "Not available."
        image_desc = state.get("image_description") or "Not available."
        # Walk the history backwards to find the newest feedback and draft.
        human_feedback = next((msg.content for msg in reversed(state["news_report"]) if isinstance(msg, HumanMessage)), None)
        last_ai_report = next((msg.content for msg in reversed(state["news_report"]) if isinstance(msg, AIMessage)), None)

        prompt = f"""You are a professional news editor. Revise the news report to address the feedback...
                    **Original Source Information:**
                    --- Transcribed Audio ---
                    "{transcribed}"
                    --- Image Description ---
                    "{image_desc}"
                    **Current Draft of News Report:**
                    "{last_ai_report}"
                    **Latest Human Feedback:**
                    "{human_feedback}"
                    Provide only the full, revised news report..."""
        revised_content = self._generate([{"role": "user", "content": [{"type": "text", "text": prompt}]}])
        print("  > Revision complete.")
        return {"news_report": add_messages(state["news_report"], [AIMessage(content=revised_content)])}

    def save_report(self, state: AgentState) -> dict:
        """Saves the latest AI-generated news report to a text file.

        Returns a ``{"final_message": ...}`` update naming the saved file,
        or an error message when there is no AI report in the history.
        """
        print("--- πŸ’Ύ SAVING REPORT ---")
        latest_report_msg = next((msg for msg in reversed(state["news_report"]) if isinstance(msg, AIMessage)), None)
        if not latest_report_msg:
            return {"final_message": "Error: No report available to save."}
        output_dir = "saved_reports"
        os.makedirs(output_dir, exist_ok=True)
        # Use the first unused sequence number. The previous
        # len(os.listdir()) + 1 scheme could silently overwrite an existing
        # report once any earlier file had been deleted.
        index = 1
        while True:
            filename = os.path.join(output_dir, f"news_report_{index}.txt")
            if not os.path.exists(filename):
                break
            index += 1
        with open(filename, "w", encoding="utf-8") as f:
            f.write(latest_report_msg.content)
        # Interpolate the actual path (the old string was a literal
        # placeholder and never reported where the file went).
        final_message = f"βœ… Report saved to: **{filename}**"
        print(f"  > {final_message}")
        return {"final_message": final_message}