File size: 6,058 Bytes
1e81c91
 
 
 
 
faaabaa
1e81c91
 
 
7bc7ddb
1e81c91
6495d55
bfad56c
 
1e81c91
 
6495d55
1e81c91
 
 
6495d55
 
 
c6f0ab2
6495d55
 
 
 
 
 
 
 
 
 
1e81c91
 
7bc7ddb
 
 
 
 
1e81c91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152241c
1e81c91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bc7ddb
1e81c91
 
 
 
 
 
 
 
 
 
 
152241c
179cd4f
 
 
 
1e81c91
179cd4f
 
 
 
 
 
152241c
 
179cd4f
69623ca
179cd4f
 
bd4417e
 
 
 
179cd4f
69623ca
179cd4f
 
 
c6f0ab2
152241c
 
 
 
 
 
 
 
 
858f321
152241c
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from utils.youtube_extractor import YoutubeExtractor
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from typing import List, Dict, Optional, Union
from class_mod.rest_qdrant import RestQdrantClient
import uuid

class DataImporter:
    """Embed text with BAAI/bge-m3 and store/search it in a Qdrant collection.

    If Qdrant cannot be reached at construction time the instance runs in
    "offline mode": ``qdrant_available`` is False and ``client`` is None.
    In offline mode, searches return ``[]``, inserts raise ``RuntimeError``
    and the cold-start query is skipped.
    """

    # Dimensionality of the collection's vectors; must match the size of the
    # embeddings produced by ``self.model``.
    VECTOR_SIZE = 1024

    def __init__(self, qdrant_url: str = "https://qdrant.taspolsd.dev", collection_name: str = "demo_bge_m3"):
        """Load the embedding model and connect to Qdrant.

        Args:
            qdrant_url: Base URL of the Qdrant endpoint.
            collection_name: Collection that texts are written to and searched in.
        """
        self.model = SentenceTransformer("BAAI/bge-m3")
        self.qdrant_url = qdrant_url
        self.client = None
        self.qdrant_available = False
        self.collection_name = collection_name
        self.youtube_extractor = YoutubeExtractor()
        # _init_qdrant() also ensures the collection exists when the connection
        # succeeds, so it is deliberately not created a second time here.
        self._init_qdrant()

    def _init_qdrant(self):
        """Connect to Qdrant and ensure the collection exists.

        On any connection failure, print a warning and fall back to offline
        mode (``client`` None, ``qdrant_available`` False) instead of raising.
        """
        try:
            self.client = RestQdrantClient(url=self.qdrant_url, timeout=15)
            # Round-trip to verify the server is actually reachable.
            self.client.get_collections()
        except Exception as e:
            print(f"Warning: Could not connect to Qdrant: {e}")
            print("Running in offline mode - vector operations will be disabled")
            self.client = None
            self.qdrant_available = False
            return
        self.qdrant_available = True
        print(f"Successfully connected to Qdrant at {self.qdrant_url}")
        self._create_collection()

    def _require_client(self) -> None:
        """Raise RuntimeError when no Qdrant connection is available."""
        if not self.qdrant_available or not self.client:
            raise RuntimeError(
                f"Qdrant is not available; cannot write to collection '{self.collection_name}'"
            )

    def _create_collection(self):
        """Create the collection if it does not already exist (best effort).

        Errors are printed rather than raised so construction still succeeds.
        Does nothing when there is no Qdrant client.
        """
        if not self.client:
            return
        try:
            existing = self.client.get_collection(self.collection_name)
        except Exception:
            # NOTE(review): assumes RestQdrantClient raises (e.g. on HTTP 404)
            # when the collection is missing. Treat that as "absent" and fall
            # through to creation rather than aborting. Confirm against
            # RestQdrantClient.
            existing = None
        if existing:
            print(f"Collection '{self.collection_name}' already exists.")
            return

        try:
            self.client.recreate_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=self.VECTOR_SIZE, distance=Distance.COSINE)
            )
            print(f"Collection '{self.collection_name}' created successfully")
        except Exception as e:
            print(f"Error creating collection: {e}")

    def encode_text(self, texts: Union[str, List[str]]) -> List[List[float]]:
        """Embed one or more texts with L2-normalised embeddings.

        Args:
            texts: A single string or a list of strings.

        Returns:
            One embedding (list of floats) per input text, in input order.
            A single string yields a one-element list.
        """
        if isinstance(texts, str):
            texts = [texts]

        embeddings = self.model.encode(texts, normalize_embeddings=True)
        return embeddings.tolist()

    def insert_text(self, text: str, metadata: Optional[Dict] = None, custom_id: Optional[str] = None) -> str:
        """Embed ``text`` and upsert it as a single point.

        Args:
            text: Text to embed; also stored under the ``"text"`` payload key.
            metadata: Extra payload fields merged on top of ``{"text": text}``.
                A ``"text"`` key here overrides the stored text.
            custom_id: Point id to use. A random UUID4 string is generated when
                it is omitted or empty.

        Returns:
            The id of the upserted point.

        Raises:
            RuntimeError: If Qdrant is unavailable (offline mode).
        """
        self._require_client()
        point_id = custom_id or str(uuid.uuid4())
        embedding = self.encode_text(text)[0]
        payload = {"text": text}

        if metadata:
            payload.update(metadata)

        self.client.upsert(
            collection_name=self.collection_name,
            points=[PointStruct(id=point_id, vector=embedding, payload=payload)]
        )

        print(f"Inserted text with ID: {point_id}")
        return point_id

    def insert_texts(self, texts: List[str], metadata_list: Optional[List[Dict]] = None) -> List[str]:
        """Embed a batch of texts and upsert them in one request.

        Args:
            texts: Texts to embed and store.
            metadata_list: Optional per-text payload extras, aligned with
                ``texts`` by index. Missing, shorter or ``None`` entries just
                contribute no extra fields.

        Returns:
            The generated UUID4 point ids, in the same order as ``texts``.
            An empty ``texts`` list is a no-op that returns ``[]``.

        Raises:
            RuntimeError: If ``texts`` is non-empty and Qdrant is unavailable.
        """
        if not texts:
            return []
        self._require_client()

        embeddings = self.encode_text(texts)
        point_ids = [str(uuid.uuid4()) for _ in texts]
        extras = metadata_list or []

        points = []
        for i, (text, embedding, point_id) in enumerate(zip(texts, embeddings, point_ids)):
            payload = {"text": text}
            if i < len(extras) and extras[i]:
                payload.update(extras[i])
            points.append(PointStruct(id=point_id, vector=embedding, payload=payload))

        self.client.upsert(collection_name=self.collection_name, points=points)
        print(f"Inserted {len(texts)} texts")
        return point_ids

    def insert_from_youtube(self, video_id: str, metadata: Optional[Dict] = None) -> Optional[str]:
        """Fetch a video's text through the YouTube extractor and insert it.

        Args:
            video_id: YouTube video id passed to ``get_full_text``.
            metadata: Extra payload fields merged over
                ``{"source": "youtube", "video_id": video_id}``.

        Returns:
            The inserted point id. Returns None if no text was extracted or any
            error occurred; errors are printed, not raised.
        """
        try:
            text = self.youtube_extractor.get_full_text(video_id)
            if text:
                video_metadata = {"source": "youtube", "video_id": video_id}
                if metadata:
                    video_metadata.update(metadata)

                return self.insert_text(text, video_metadata)
            return None
        except Exception as e:
            print(f"Error extracting from YouTube: {e}")
            return None

    def search_similar(self, query: str, limit: int = 1) -> List[Dict]:
        """Return the ``limit`` stored texts most similar to ``query``.

        Args:
            query: Text to embed and search with.
            limit: Maximum number of hits to return.

        Returns:
            A list of dicts with keys ``id``, ``score`` (float; 0.0 when the
            server gave no score), ``text`` and ``metadata`` (all other payload
            fields). Returns ``[]`` in offline mode.

        Raises:
            ValueError: If embedding or the search request fails; the original
                exception is chained as ``__cause__``.
        """
        if not self.qdrant_available or not self.client:
            print("Warning: Qdrant not available, returning empty results")
            return []

        try:
            query_embedding = self.encode_text(query)[0]

            results = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding,
                limit=limit,
                timeout=15
            )
            print(f"Search results: {results}")
            hits = []
            for result in results['result']:
                # A hit may carry a null payload; treat that as an empty one.
                payload = result.get('payload') or {}
                score = result.get('score')
                hits.append({
                    "id": result['id'],
                    "score": float(score) if score else 0.0,
                    "text": payload.get("text", ""),
                    "metadata": {k: v for k, v in payload.items() if k != "text"}
                })
            return hits
        except Exception as e:
            print(f"Error searching: {e}")
            raise ValueError(f"Search failed: {str(e)}") from e

    def coldStartDatabase(self):
        """Issue one throwaway search; errors are printed, never raised.

        Presumably meant to warm up the embedding model and the Qdrant
        connection (the result is only printed). Skipped in offline mode.
        """
        if not self.qdrant_available or not self.client:
            print("Warning: Qdrant not available, skipping cold start")
            return
        coldstart_texts = "I want to go to Chiang Mai"
        try:
            query_embedding = self.encode_text(coldstart_texts)[0]
            results = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding,
                limit=1,
                timeout=10
            )
            print(f"Cold start results: {results}")
        except Exception as e:
            print(f"finish cold start, with error: {e}")