File size: 2,143 Bytes
2f264ab
 
 
 
 
8428312
 
2f264ab
 
 
 
 
cf1cddb
2f264ab
 
 
 
8428312
2f264ab
82915e5
2f264ab
 
cf1cddb
 
2f264ab
 
 
 
 
 
 
 
 
 
 
 
 
 
82915e5
 
 
 
 
 
 
 
 
2f264ab
82915e5
2f264ab
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import logging

import torch
from transformers import pipeline

logger = logging.getLogger(__name__)
# BUG FIX: datefmt only takes effect when the format string contains
# %(asctime)s — the original passed datefmt without a format, so the
# configured date format was silently ignored (and no timestamp was logged).
logging.basicConfig(
    filename="pipeline.log",
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

class Pipeline:
    """Turn raw text into JSON-formatted flashcards via a Hugging Face
    text-generation model.

    The heavy lifting is delegated to ``transformers.pipeline``; this class
    holds the system prompt and builds a fresh conversation per request.
    """

    def __init__(self, model_name: str = "Qwen/Qwen2.5-7B-Instruct"):
        """Load the generation pipeline and the base (system-only) conversation.

        Args:
            model_name: Hugging Face model id to load.
        """
        self.torch_pipe = pipeline(
            "text-generation",
            model_name,
            torch_dtype="auto",
            device_map="auto"
        )
        self.device = self._determine_device()
        # Lazy %-formatting: the logging module interpolates only if emitted.
        logger.info("device type: %s", self.device)
        # Base conversation holding only the system prompt.
        # extract_flashcards() copies this per call so user turns never
        # accumulate here (the original appended them forever).
        self.messages = [
            {"role": "system", "content": """You are an expert flashcard creator.
            - You ALWAYS include a single knowledge item per flashcard.
            - You ALWAYS respond in valid JSON format.
            - You ALWAYS make flashcards accurate and comprehensive.
            - If the text includes code snippets, you consider snippets a knowledge item testing the user's understanding of how to write the code and what it does.

            Format responses like the example below.

            EXAMPLE:
            [
                {"question": "What is AI?", "answer": "Artificial Intelligence."},
                {"question": "What is ML?", "answer": "Machine Learning."}
            ]
            """},
        ]

    def extract_flashcards(self, content: str = "", max_new_tokens: int = 1024) -> dict:
        """Generate flashcards for ``content``.

        Args:
            content: Source text to turn into flashcards.
            max_new_tokens: Generation budget forwarded to the pipeline.

        Returns:
            The assistant's reply message produced by the pipeline, i.e. the
            last element of ``generated_text`` — a chat-message dict
            (``{"role": "assistant", "content": ...}``), not a plain string;
            the original ``-> str`` annotation was wrong.

        Raises:
            ValueError: If the underlying pipeline call fails.
        """
        # BUG FIX: build a per-call message list instead of mutating
        # self.messages. The original appended every user turn (and never the
        # assistant replies, and even on failure), so state grew unboundedly
        # and leaked across calls.
        messages = self.messages + [{"role": "user", "content": content}]
        try:
            # pipeline() returns [{"generated_text": [<full chat>, ...]}];
            # the last chat entry is the newly generated assistant message.
            return self.torch_pipe(
                messages,
                max_new_tokens=max_new_tokens
            )[0]["generated_text"][-1]
        except Exception as e:
            # logger.exception records the full traceback, not just str(e).
            logger.exception("Error extracting flashcards")
            # Chain the cause so callers keep the original traceback.
            # (Also fixes the "Error extraction flashcards" typo.)
            raise ValueError(f"Error extracting flashcards: {e}") from e

    def _determine_device(self) -> torch.device:
        """Return the best available device: CUDA, then Apple MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")