Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Changes to FastAPI and neural searcher
Browse files
- app.py +11 -3
- neural_searcher.py +6 -4
- requirements.txt +1 -0
    	
        app.py
    CHANGED
    
    | @@ -1,4 +1,4 @@ | |
| 1 | 
            -
            from fastapi import FastAPI
         | 
| 2 | 
             
            from neural_searcher import NeuralSearcher
         | 
| 3 | 
             
            from huggingface_hub import login
         | 
| 4 | 
             
            import os
         | 
| @@ -10,5 +10,13 @@ app = FastAPI() | |
| 10 | 
             
            neural_searcher = NeuralSearcher(collection_name=os.getenv('COLLECTION_NAME'))
         | 
| 11 |  | 
| 12 | 
             
            @app.get("/api/search")
         | 
| 13 | 
            -
            def  | 
| 14 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from fastapi import FastAPI, HTTPException
         | 
| 2 | 
             
            from neural_searcher import NeuralSearcher
         | 
| 3 | 
             
            from huggingface_hub import login
         | 
| 4 | 
             
            import os
         | 
|  | |
| 10 | 
             
            neural_searcher = NeuralSearcher(collection_name=os.getenv('COLLECTION_NAME'))
         | 
| 11 |  | 
| 12 | 
             
            @app.get("/api/search")
         | 
| 13 | 
            +
            async def search(q: str):
         | 
| 14 | 
            +
                if not q:
         | 
| 15 | 
            +
                    raise HTTPException(status_code=400, detail="Bad request.")
         | 
| 16 | 
            +
                
         | 
| 17 | 
            +
                try:
         | 
| 18 | 
            +
                    data = await neural_searcher.search(text=q)
         | 
| 19 | 
            +
                    return data
         | 
| 20 | 
            +
                except:
         | 
| 21 | 
            +
                    raise HTTPException(status_code=500, detail="Internal server error.")
         | 
| 22 | 
            +
             | 
    	
        neural_searcher.py
    CHANGED
    
    | @@ -17,24 +17,26 @@ class NeuralSearcher: | |
| 17 | 
             
                    dense_query = self.dense_model.encode(text).tolist()
         | 
| 18 | 
             
                    sparse_query = self.sparse_model.query_embed(text)
         | 
| 19 |  | 
| 20 | 
            -
                    search_result = self.qdrant_client. | 
| 21 | 
             
                        collection_name= self.collection_name,
         | 
|  | |
| 22 | 
             
                        prefetch=[
         | 
| 23 | 
             
                            models.Prefetch(
         | 
| 24 | 
             
                                query=dense_query,
         | 
| 25 | 
             
                                using=os.getenv('DENSE_MODEL'),
         | 
| 26 | 
            -
                                limit= | 
| 27 | 
             
                            ),
         | 
| 28 | 
             
                            models.Prefetch(
         | 
| 29 | 
             
                                query=next(sparse_query).as_object(),
         | 
| 30 | 
             
                                using=os.getenv('SPARSE_MODEL'),
         | 
| 31 | 
            -
                                limit= | 
| 32 | 
             
                            )
         | 
| 33 | 
             
                        ],
         | 
| 34 | 
             
                        query=models.FusionQuery(
         | 
| 35 | 
             
                            fusion=models.Fusion.RRF
         | 
| 36 | 
             
                        ),
         | 
| 37 | 
            -
                         | 
|  | |
| 38 | 
             
                    ).points
         | 
| 39 |  | 
| 40 | 
             
                    payloads = [hit.payload for hit in search_result]
         | 
|  | |
| 17 | 
             
                    dense_query = self.dense_model.encode(text).tolist()
         | 
| 18 | 
             
                    sparse_query = self.sparse_model.query_embed(text)
         | 
| 19 |  | 
| 20 | 
            +
                    search_result = self.qdrant_client.query_points_groups(
         | 
| 21 | 
             
                        collection_name= self.collection_name,
         | 
| 22 | 
            +
                        group_by="dbid",
         | 
| 23 | 
             
                        prefetch=[
         | 
| 24 | 
             
                            models.Prefetch(
         | 
| 25 | 
             
                                query=dense_query,
         | 
| 26 | 
             
                                using=os.getenv('DENSE_MODEL'),
         | 
| 27 | 
            +
                                limit=100
         | 
| 28 | 
             
                            ),
         | 
| 29 | 
             
                            models.Prefetch(
         | 
| 30 | 
             
                                query=next(sparse_query).as_object(),
         | 
| 31 | 
             
                                using=os.getenv('SPARSE_MODEL'),
         | 
| 32 | 
            +
                                limit=100
         | 
| 33 | 
             
                            )
         | 
| 34 | 
             
                        ],
         | 
| 35 | 
             
                        query=models.FusionQuery(
         | 
| 36 | 
             
                            fusion=models.Fusion.RRF
         | 
| 37 | 
             
                        ),
         | 
| 38 | 
            +
                        score_threshold=0.8,
         | 
| 39 | 
            +
                        limit = 10
         | 
| 40 | 
             
                    ).points
         | 
| 41 |  | 
| 42 | 
             
                    payloads = [hit.payload for hit in search_result]
         | 
    	
        requirements.txt
    CHANGED
    
    | @@ -8,3 +8,4 @@ python-dotenv | |
| 8 | 
             
            qdrant-client
         | 
| 9 | 
             
            qdrant-client[fastembed]>=1.8.2
         | 
| 10 | 
             
            sentence-transformers
         | 
|  | 
|  | |
| 8 | 
             
            qdrant-client
         | 
| 9 | 
             
            qdrant-client[fastembed]>=1.8.2
         | 
| 10 | 
             
            sentence-transformers
         | 
| 11 | 
            +
            firebase
         | 
