# from dotenv import find_dotenv, load_dotenv
# _ = load_dotenv(find_dotenv())
import solara
import polars as pl

df = pl.read_csv(
    "https://drive.google.com/uc?export=download&id=1uD3h7xYxr9EoZ0Ggoh99JtQXa3AxtxyU"
)
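
# Clean up the raw data: title-case album and song names, fill missing lyrics,
# and build a combined "text" column ("# Album: Song\n\nLyrics") that is later
# embedded for semantic search.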
import string

df = df.with_columns(
    pl.Series("Album", [string.capwords(album) for album in df["Album"]])
)
df = df.with_columns(pl.Series("Song", [string.capwords(song) for song in df["Song"]]))
df = df.with_columns(pl.col("Lyrics").fill_null("None"))
df = df.with_columns(
    text=pl.lit("# ")
    + pl.col("Album")
    + pl.lit(": ")
    + pl.col("Song")
    + pl.lit("\n\n")
    + pl.col("Lyrics")
)
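
# Build a fresh LanceDB vector store, embedding the "text" column with a
# small sentence-transformers model.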
import shutil
import lancedb

shutil.rmtree("test_lancedb", ignore_errors=True)
db = lancedb.connect("test_lancedb")

from lancedb.embeddings import get_registry

embeddings = (
    get_registry()
    .get("sentence-transformers")
    .create(name="TaylorAI/gte-tiny", device="cpu")
)

from lancedb.pydantic import LanceModel, Vector


class Songs(LanceModel):
    Song: str
    Lyrics: str
    Album: str
    Artist: str
    text: str = embeddings.SourceField()
    vector: Vector(embeddings.ndims()) = embeddings.VectorField()


table = db.create_table("Songs", schema=Songs)
table.add(data=df)
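
# Quick sanity check (example, not executed here):
#   table.search("songs about war").limit(3).to_polars()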

import os
from typing import Optional
from langchain_community.chat_models import ChatOpenAI


class ChatOpenRouter(ChatOpenAI):
    """ChatOpenAI subclass pointed at the OpenRouter API."""

    openai_api_base: str
    openai_api_key: str
    model_name: str

    def __init__(
        self,
        model_name: str,
        openai_api_key: Optional[str] = None,
        openai_api_base: str = "https://openrouter.ai/api/v1",
        **kwargs,
    ):
        # Fall back to the OPENROUTER_API_KEY environment variable only when
        # no key is passed in explicitly.
        openai_api_key = openai_api_key or os.getenv("OPENROUTER_API_KEY")
        super().__init__(
            openai_api_base=openai_api_base,
            openai_api_key=openai_api_key,
            model_name=model_name,
            **kwargs,
        )

llm_openrouter = ChatOpenRouter(
    model_name="meta-llama/llama-3.1-405b-instruct", temperature=0.1
)
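
# Retrieval helpers for the vector-search (RAG) pipeline.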
def get_relevant_texts(query, table=table):
    results = table.search(query).limit(5).to_polars()
    # Iterate over the rows actually returned, which may be fewer than five.
    return " ".join(
        [results["text"][i] + "\n\n---\n\n" for i in range(len(results))]
    )


def generate_prompt(query, table=table):
    return (
        "Answer the question based only on the following context:\n\n"
        + get_relevant_texts(query, table)
        + "\n\nQuestion: "
        + query
    )


def generate_response(query, table=table):
    prompt = generate_prompt(query, table)
    response = llm_openrouter.invoke(input=prompt)
    return response.content
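
# Example usage (not executed): generate_response("Which song is about war?")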

import kuzu

shutil.rmtree("test_kuzudb", ignore_errors=True)
# Note: `db` is rebound from the LanceDB connection to the Kùzu database;
# the LanceDB `table` handle created above stays valid.
db = kuzu.Database("test_kuzudb")
conn = kuzu.Connection(db)

# Create schema
conn.execute("CREATE NODE TABLE ARTIST(name STRING, PRIMARY KEY (name))")
conn.execute("CREATE NODE TABLE ALBUM(name STRING, PRIMARY KEY (name))")
conn.execute(
    "CREATE NODE TABLE SONG(ID SERIAL, name STRING, lyrics STRING, PRIMARY KEY(ID))"
)
conn.execute("CREATE REL TABLE IN_ALBUM(FROM SONG TO ALBUM)")
conn.execute("CREATE REL TABLE FROM_ARTIST(FROM ALBUM TO ARTIST)")

# Insert nodes
for artist in df["Artist"].unique():
    conn.execute(f"CREATE (artist:ARTIST {{name: '{artist}'}})")
for album in df["Album"].unique():
    conn.execute(f"""CREATE (album:ALBUM {{name: "{album}"}})""")
for song, lyrics in df.select(["Song", "text"]).unique().rows():
    # Naive escaping: swap double quotes in the lyrics for single quotes so
    # they can sit inside a double-quoted Cypher string literal.
    replaced_lyrics = lyrics.replace('"', "'")
    conn.execute(
        f"""CREATE (song:SONG {{name: "{song}", lyrics: "{replaced_lyrics}"}})"""
    )

# Insert edges
for song, album, lyrics in df.select(["Song", "Album", "text"]).rows():
    replaced_lyrics = lyrics.replace('"', "'")
    conn.execute(
        f"""
        MATCH (song:SONG), (album:ALBUM)
        WHERE song.name = "{song}" AND song.lyrics = "{replaced_lyrics}" AND album.name = "{album}"
        CREATE (song)-[:IN_ALBUM]->(album)
        """
    )
for album, artist in df.select(["Album", "Artist"]).unique().rows():
    conn.execute(
        f"""
        MATCH (album:ALBUM), (artist:ARTIST)
        WHERE album.name = "{album}" AND artist.name = "{artist}"
        CREATE (album)-[:FROM_ARTIST]->(artist)
        """
    )
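
# Sanity check: list the songs on "The Black Album" with a hand-written query.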
response = conn.execute(
    """
    MATCH (a:ALBUM {name: 'The Black Album'})<-[:IN_ALBUM]-(s:SONG) RETURN s.name
    """
)
df_response = response.get_as_pl()

from langchain_community.graphs import KuzuGraph

# Wrap the Kùzu database so LangChain can render its schema for the prompt.
graph = KuzuGraph(db)
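

# Prompt template that asks the LLM to translate a user question into
# Kùzu-dialect Cypher, constrained to the schema above.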
def generate_kuzu_prompt(user_query):
    return """Task: Generate Kùzu Cypher statement to query a graph database.
Instructions:
Generate the Kùzu dialect of Cypher with the following rules in mind:
1. Do not omit the relationship pattern. Always use `()-[]->()` instead of `()->()`.
2. Do not include triple backticks ``` in your response. Return only Cypher.
3. Do not return any notes or comments in your response.
Use only the provided relationship types and properties in the schema.
Do not use any other relationship types or properties that are not provided.
Schema:\n""" + graph.get_schema + """\nExample:
The question is:\n"Which songs does the Load album have?"
MATCH (a:ALBUM {name: 'Load'})<-[:IN_ALBUM]-(s:SONG) RETURN s.name
Note: Do not include any explanations or apologies in your responses.
Do not respond to any questions that ask for anything other than constructing a Cypher statement.
Do not include any text except the generated Cypher statement.
The question is:\n""" + user_query


def generate_final_prompt(query, cypher_query, col_name, _values):
    return f"""You are an assistant that helps to form nice and human understandable answers.
The information part contains the provided information that you must use to construct an answer.
The provided information is authoritative; you must never doubt it or try to use your internal knowledge to correct it.
Make the answer sound like a response to the question. Do not mention that you based the result on the given information.
Here is an example:
Question: Which managers own Neo4j stocks?
Context: [manager:CTL LLC, manager:JANE STREET GROUP LLC]
Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks.
Follow this example when generating answers.
If the provided information is empty, say that you don't know the answer.
Query:\n{cypher_query}
Information:
[{col_name}: {_values}]
Question: {query}
Helpful Answer:
"""
def generate_kg_response(query):
    prompt = generate_kuzu_prompt(query)
    cypher_query_response = llm_openrouter.invoke(input=prompt)
    cypher_query = cypher_query_response.content
    response = conn.execute(
        f"""
        {cypher_query}
        """
    )
    df_result = response.get_as_pl()
    col_name = df_result.columns[0]
    _values = df_result[col_name].to_list()
    final_prompt = generate_final_prompt(query, cypher_query, col_name, _values)
    final_response = llm_openrouter.invoke(input=final_prompt)
    final_response = final_response.content
    return final_response, cypher_query
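
# Example usage (not executed):
#   answer, cypher = generate_kg_response("How many songs does the Black Album have?")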

query = solara.reactive("How many songs does the black album have?")


@solara.component
def Page():
    with solara.Column(margin=10):
        solara.Markdown("# Metallica Song Finder graph-only")
        solara.InputText("Enter some query:", query, continuous_update=False)
        if query.value != "":
            response, cypher_query = generate_kg_response(query.value)
            solara.Markdown("## Answer:")
            solara.Markdown(response)
            solara.Markdown("## Cypher query:")
            solara.Markdown(cypher_query)
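
# Start the app locally with `solara run app.py` (assuming this file is app.py).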