File size: 4,949 Bytes
1e81c91
 
 
 
 
faaabaa
1e81c91
 
 
7bc7ddb
1e81c91
6495d55
bfad56c
 
1e81c91
 
6495d55
1e81c91
 
 
6495d55
 
 
faaabaa
6495d55
 
 
 
 
 
 
 
 
 
1e81c91
 
7bc7ddb
 
 
 
 
1e81c91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bc7ddb
1e81c91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from utils.youtube_extractor import YoutubeExtractor
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from typing import List, Dict, Optional, Union
from class_mod.rest_qdrant import RestQdrantClient
import uuid

class DataImporter:
    """Embed texts with a SentenceTransformer model and store/search them in Qdrant.

    On construction the embedding model is loaded, a connection to Qdrant is
    attempted and, if it succeeds, the target collection is created when it
    is reported as missing. If the connection fails the instance stays usable
    for ``encode_text`` only: ``client`` is ``None``, ``qdrant_available`` is
    ``False`` and every vector-store operation raises ``RuntimeError``.
    """

    # Vector size used when creating the collection. It has to match the
    # length of the embeddings returned by ``self.model``.
    VECTOR_SIZE = 1024

    def __init__(self, qdrant_url: str = "https://qdrant.taspolsd.dev", collection_name: str = "demo_bge_m3"):
        """Load the embedding model and try to connect to Qdrant.

        Args:
            qdrant_url: Base URL of the Qdrant server.
            collection_name: Collection that points are written to and searched in.
        """
        self.model = SentenceTransformer("BAAI/bge-m3")
        self.qdrant_url = qdrant_url
        self.collection_name = collection_name
        self.youtube_extractor = YoutubeExtractor()
        self.client = None
        self.qdrant_available = False

        # _init_qdrant() also ensures the collection exists once connected,
        # so no second _create_collection() call is needed here.
        self._init_qdrant()

    def _init_qdrant(self):
        """Initialize the Qdrant client, falling back to offline mode on failure."""
        try:
            self.client = RestQdrantClient(url=self.qdrant_url)
            # Listing collections doubles as a connectivity test.
            self.client.get_collections()
        except Exception as e:
            print(f"Warning: Could not connect to Qdrant: {e}")
            print("Running in offline mode - vector operations will be disabled")
            self.client = None
            self.qdrant_available = False
        else:
            self.qdrant_available = True
            print(f"Successfully connected to Qdrant at {self.qdrant_url}")
            self._create_collection()

    def _create_collection(self):
        """Create the collection if the server reports that it does not exist.

        Does nothing in offline mode. Errors are printed, never raised.
        """
        if self.client is None:
            return

        try:
            existing = self.client.get_collection(self.collection_name)
            if existing:
                print(f"Collection '{self.collection_name}' already exists.")
                return

            # NOTE(review): this branch is only reached when get_collection()
            # returns a falsy value for a missing collection. If
            # RestQdrantClient raises instead (as the official qdrant-client
            # does), the collection is never created here - confirm.
            self.client.recreate_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=self.VECTOR_SIZE, distance=Distance.COSINE),
            )
            print(f"Collection '{self.collection_name}' created successfully")
        except Exception as e:
            print(f"Error creating collection: {e}")

    def _require_client(self):
        """Return the Qdrant client, or raise ``RuntimeError`` in offline mode."""
        if self.client is None:
            raise RuntimeError(
                "Qdrant is not available (offline mode); vector operations are disabled"
            )
        return self.client

    def encode_text(self, texts: Union[str, List[str]]) -> List[List[float]]:
        """Embed one string or a list of strings.

        Args:
            texts: A single text or a list of texts.

        Returns:
            One embedding (list of floats) per input text; a single string
            yields a one-element list.
        """
        if isinstance(texts, str):
            texts = [texts]

        embeddings = self.model.encode(texts, normalize_embeddings=True)
        return embeddings.tolist()

    def insert_text(self, text: str, metadata: Optional[Dict] = None, custom_id: Optional[str] = None) -> str:
        """Embed ``text`` and upsert it as a single point.

        Args:
            text: Text to store; saved in the payload under ``"text"``.
            metadata: Extra payload fields merged on top of ``{"text": text}``.
            custom_id: Point id to use; a random UUID4 string when omitted.

        Returns:
            The id of the upserted point.

        Raises:
            RuntimeError: If Qdrant is unavailable (offline mode).
        """
        # Fail before paying for the embedding when there is nowhere to store it.
        client = self._require_client()

        point_id = custom_id or str(uuid.uuid4())
        embedding = self.encode_text(text)[0]

        payload = {"text": text}
        if metadata:
            payload.update(metadata)

        client.upsert(
            collection_name=self.collection_name,
            points=[PointStruct(id=point_id, vector=embedding, payload=payload)],
        )

        print(f"Inserted text with ID: {point_id}")
        return point_id

    def insert_texts(self, texts: List[str], metadata_list: Optional[List[Dict]] = None) -> List[str]:
        """Embed several texts and upsert them in one request.

        Args:
            texts: Texts to store.
            metadata_list: Optional per-text payload fields, matched to
                ``texts`` by index. It may be shorter than ``texts`` and may
                contain ``None`` entries; those texts get no extra fields.

        Returns:
            The generated UUID4 ids, in the same order as ``texts``. An empty
            input returns ``[]`` without contacting the model or Qdrant.

        Raises:
            RuntimeError: If Qdrant is unavailable (offline mode).
        """
        if not texts:
            return []

        client = self._require_client()
        metadata_list = metadata_list or []

        embeddings = self.encode_text(texts)
        point_ids = [str(uuid.uuid4()) for _ in texts]

        points = []
        for i, (text, embedding, point_id) in enumerate(zip(texts, embeddings, point_ids)):
            payload = {"text": text}
            extra = metadata_list[i] if i < len(metadata_list) else None
            if extra:
                payload.update(extra)

            points.append(PointStruct(id=point_id, vector=embedding, payload=payload))

        client.upsert(collection_name=self.collection_name, points=points)
        print(f"Inserted {len(points)} texts")
        return point_ids

    def insert_from_youtube(self, video_id: str, metadata: Optional[Dict] = None) -> Optional[str]:
        """Fetch the text of a YouTube video and store it as one point.

        The payload carries ``source="youtube"`` and ``video_id`` in addition
        to any caller-supplied ``metadata`` (which may override them).

        Args:
            video_id: Identifier passed to ``YoutubeExtractor.get_full_text``.
            metadata: Extra payload fields.

        Returns:
            The new point id, or ``None`` when no text was extracted or any
            step (extraction or insertion) raised.
        """
        try:
            text = self.youtube_extractor.get_full_text(video_id)
            if not text:
                return None

            video_metadata = {"source": "youtube", "video_id": video_id}
            if metadata:
                video_metadata.update(metadata)

            return self.insert_text(text, video_metadata)
        except Exception as e:
            # Deliberately best-effort: callers get None instead of an exception.
            print(f"Error extracting from YouTube: {e}")
            return None

    def search_similar(self, query: str, limit: int = 5) -> List[Dict]:
        """Return the stored points most similar to ``query``.

        Args:
            query: Text to embed and search with.
            limit: Maximum number of hits requested from Qdrant.

        Returns:
            One dict per hit with keys ``id``, ``score``, ``text`` (payload
            ``"text"`` or ``""``) and ``metadata`` (all other payload fields).

        Raises:
            RuntimeError: If Qdrant is unavailable (offline mode).
        """
        client = self._require_client()
        query_embedding = self.encode_text(query)[0]

        results = client.search(
            collection_name=self.collection_name,
            query_vector=query_embedding,
            limit=limit,
        )

        hits = []
        for result in results:
            # A point stored without payload comes back with payload=None.
            payload = result.payload or {}
            hits.append(
                {
                    "id": result.id,
                    "score": result.score,
                    "text": payload.get("text", ""),
                    "metadata": {k: v for k, v in payload.items() if k != "text"},
                }
            )
        return hits