File size: 4,949 Bytes
1e81c91 faaabaa 1e81c91 7bc7ddb 1e81c91 6495d55 bfad56c 1e81c91 6495d55 1e81c91 6495d55 faaabaa 6495d55 1e81c91 7bc7ddb 1e81c91 7bc7ddb 1e81c91 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
from utils.youtube_extractor import YoutubeExtractor
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from typing import List, Dict, Optional, Union
from class_mod.rest_qdrant import RestQdrantClient
import uuid
class DataImporter:
    """Embed text with a SentenceTransformer model and store/search it in Qdrant.

    If the Qdrant server cannot be reached while the importer is constructed,
    it switches to offline mode: ``client`` is ``None``, ``qdrant_available``
    is ``False`` and every vector operation raises ``RuntimeError``.
    """

    # Vector size the collection is created with; it has to match the
    # dimension of the embeddings produced by the model loaded in __init__.
    VECTOR_SIZE = 1024

    def __init__(self, qdrant_url: str = "https://qdrant.taspolsd.dev", collection_name: str = "demo_bge_m3"):
        """Load the embedding model and try to connect to Qdrant.

        Args:
            qdrant_url: Base URL of the Qdrant server.
            collection_name: Collection used for all inserts and searches.
        """
        self.model = SentenceTransformer("BAAI/bge-m3")
        self.qdrant_url = qdrant_url
        self.client = None
        self.qdrant_available = False
        self.collection_name = collection_name
        self.youtube_extractor = YoutubeExtractor()
        # _init_qdrant() creates the collection itself once the connection
        # succeeds, so no second _create_collection() call is needed here.
        self._init_qdrant()

    def _init_qdrant(self):
        """Initialize Qdrant client with error handling"""
        try:
            self.client = RestQdrantClient(url=self.qdrant_url)
            # Test connection
            self.client.get_collections()
            self.qdrant_available = True
            print(f"Successfully connected to Qdrant at {self.qdrant_url}")
            self._create_collection()
        except Exception as e:
            # Deliberate best-effort: a missing server must not prevent the
            # importer from being constructed.
            print(f"Warning: Could not connect to Qdrant: {e}")
            print("Running in offline mode - vector operations will be disabled")
            self.client = None
            self.qdrant_available = False

    def _require_client(self):
        """Return the Qdrant client, or raise RuntimeError in offline mode."""
        if self.client is None:
            raise RuntimeError(
                "Qdrant client is not available (offline mode); "
                "vector operations are disabled"
            )
        return self.client

    def _create_collection(self):
        """Create the collection unless it already exists.

        Does nothing in offline mode. Errors are printed, not raised.
        """
        if self.client is None:
            return
        try:
            # NOTE(review): if get_collection() raises for a missing collection
            # (the upstream QdrantClient does), the creation branch below is
            # never reached - confirm RestQdrantClient's behaviour.
            existing = self.client.get_collection(self.collection_name)
            if existing:
                print(f"Collection '{self.collection_name}' already exists.")
                return
            self.client.recreate_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=self.VECTOR_SIZE, distance=Distance.COSINE),
            )
            print(f"Collection '{self.collection_name}' created successfully")
        except Exception as e:
            print(f"Error creating collection: {e}")

    def encode_text(self, texts: Union[str, List[str]]) -> List[List[float]]:
        """Return one normalized embedding per input text.

        A single string is treated as a one-element list, so the result is
        always a list of vectors.
        """
        if isinstance(texts, str):
            texts = [texts]
        embeddings = self.model.encode(texts, normalize_embeddings=True)
        return embeddings.tolist()

    def insert_text(self, text: str, metadata: Optional[Dict] = None, custom_id: Optional[str] = None) -> str:
        """Embed ``text`` and upsert it as a single point.

        Args:
            text: Text to embed; stored in the payload under ``"text"``.
            metadata: Extra payload fields merged on top of ``{"text": text}``.
            custom_id: Point ID to use instead of a random UUID4. Qdrant only
                accepts unsigned integers or UUIDs as point IDs.

        Returns:
            The ID of the upserted point.

        Raises:
            RuntimeError: If the importer is in offline mode.
        """
        client = self._require_client()
        point_id = custom_id or str(uuid.uuid4())
        embedding = self.encode_text(text)[0]
        payload = {"text": text}
        if metadata:
            payload.update(metadata)
        client.upsert(
            collection_name=self.collection_name,
            points=[PointStruct(id=point_id, vector=embedding, payload=payload)],
        )
        print(f"Inserted text with ID: {point_id}")
        return point_id

    def insert_texts(self, texts: List[str], metadata_list: Optional[List[Dict]] = None) -> List[str]:
        """Embed several texts and upsert them in one request.

        Args:
            texts: Texts to embed and store.
            metadata_list: Optional payload extras matched to ``texts`` by
                position; texts beyond its length get no extra metadata.

        Returns:
            The generated point IDs, in the order of ``texts``. An empty
            input returns ``[]`` without contacting the model or Qdrant.

        Raises:
            RuntimeError: If the importer is in offline mode.
        """
        if not texts:
            return []
        client = self._require_client()
        embeddings = self.encode_text(texts)
        point_ids = [str(uuid.uuid4()) for _ in texts]
        points = []
        for i, (text, embedding, point_id) in enumerate(zip(texts, embeddings, point_ids)):
            payload = {"text": text}
            if metadata_list and i < len(metadata_list):
                payload.update(metadata_list[i])
            points.append(PointStruct(id=point_id, vector=embedding, payload=payload))
        client.upsert(collection_name=self.collection_name, points=points)
        print(f"Inserted {len(texts)} texts")
        return point_ids

    def insert_from_youtube(self, video_id: str, metadata: Optional[Dict] = None) -> Optional[str]:
        """Fetch the text of a YouTube video and insert it as one point.

        Args:
            video_id: Identifier passed to the YouTube extractor.
            metadata: Extra payload fields merged on top of the default
                ``{"source": "youtube", "video_id": video_id}``.

        Returns:
            The ID of the inserted point, or ``None`` when no text was
            extracted or any step failed (the error is printed).
        """
        try:
            # Extract text from YouTube (assuming your YoutubeExtractor has this method)
            text = self.youtube_extractor.get_full_text(video_id)
            if not text:
                return None
            video_metadata = {"source": "youtube", "video_id": video_id}
            if metadata:
                video_metadata.update(metadata)
            return self.insert_text(text, video_metadata)
        except Exception as e:
            print(f"Error extracting from YouTube: {e}")
            return None

    def search_similar(self, query: str, limit: int = 5) -> List[Dict]:
        """Return the stored points most similar to ``query``.

        Args:
            query: Text to embed and search with.
            limit: Maximum number of hits requested from Qdrant.

        Returns:
            One dict per hit with the keys ``id``, ``score``, ``text`` and
            ``metadata`` (all payload fields except ``"text"``).

        Raises:
            RuntimeError: If the importer is in offline mode.
        """
        # Check the client first so offline mode fails before the costly encode.
        client = self._require_client()
        query_embedding = self.encode_text(query)[0]
        results = client.search(
            collection_name=self.collection_name,
            query_vector=query_embedding,
            limit=limit,
        )
        hits = []
        for result in results:
            # A hit may carry no payload; treat that as an empty one.
            payload = result.payload or {}
            hits.append(
                {
                    "id": result.id,
                    "score": result.score,
                    "text": payload.get("text", ""),
                    "metadata": {k: v for k, v in payload.items() if k != "text"},
                }
            )
        return hits
|