File size: 3,943 Bytes
8a58cf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""Slides parser.

Contains parsers for .pptx files.

"""

import os
from pathlib import Path
from typing import Dict

from gpt_index.readers.file.base_parser import BaseParser


class PptxParser(BaseParser):
    """Powerpoint parser.

    Extract text, caption images, and specify slides.

    """

    def _init_parser(self) -> Dict:
        """Init parser."""
        try:
            from pptx import Presentation  # noqa: F401
        except ImportError:
            raise ImportError(
                "The package `python-pptx` is required to read Powerpoint files: "
                "`pip install python-pptx`"
            )
        try:
            import torch  # noqa: F401
        except ImportError:
            raise ImportError(
                "The package `pytorch` is required to caption images: "
                "`pip install torch`"
            )
        try:
            from transformers import (
                AutoTokenizer,
                VisionEncoderDecoderModel,
                ViTFeatureExtractor,
            )
        except ImportError:
            raise ImportError(
                "The package `transformers` is required to caption images: "
                "`pip install transformers`"
            )
        try:
            from PIL import Image  # noqa: F401
        except ImportError:
            raise ImportError(
                "PIL is required to read image files: " "`pip install Pillow`"
            )

        model = VisionEncoderDecoderModel.from_pretrained(
            "nlpconnect/vit-gpt2-image-captioning"
        )
        feature_extractor = ViTFeatureExtractor.from_pretrained(
            "nlpconnect/vit-gpt2-image-captioning"
        )
        tokenizer = AutoTokenizer.from_pretrained(
            "nlpconnect/vit-gpt2-image-captioning"
        )

        return {
            "feature_extractor": feature_extractor,
            "model": model,
            "tokenizer": tokenizer,
        }

    def caption_image(self, tmp_image_file: str) -> str:
        """Generate text caption of image."""
        import torch
        from PIL import Image

        model = self.parser_config["model"]
        feature_extractor = self.parser_config["feature_extractor"]
        tokenizer = self.parser_config["tokenizer"]

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)

        max_length = 16
        num_beams = 4
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

        i_image = Image.open(tmp_image_file)
        if i_image.mode != "RGB":
            i_image = i_image.convert(mode="RGB")

        pixel_values = feature_extractor(
            images=[i_image], return_tensors="pt"
        ).pixel_values
        pixel_values = pixel_values.to(device)

        output_ids = model.generate(pixel_values, **gen_kwargs)

        preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return preds[0].strip()

    def parse_file(self, file: Path, errors: str = "ignore") -> str:
        """Parse file."""
        from pptx import Presentation

        presentation = Presentation(file)
        result = ""
        for i, slide in enumerate(presentation.slides):
            result += f"\n\nSlide #{i}: \n"
            for shape in slide.shapes:
                if hasattr(shape, "image"):
                    image = shape.image
                    # get image "file" contents
                    image_bytes = image.blob
                    # temporarily save the image to feed into model
                    image_filename = f"tmp_image.{image.ext}"
                    with open(image_filename, "wb") as f:
                        f.write(image_bytes)
                    result += f"\n Image: {self.caption_image(image_filename)}\n\n"

                    os.remove(image_filename)
                if hasattr(shape, "text"):
                    result += f"{shape.text}\n"

        return result