import argparse
import os
import sys

import pandas as pd
import torch
import yaml
from openai import OpenAI
from transformers import LongformerTokenizer, LongformerModel

# Put the local `model` directory on the import path.
sys.path.append(os.path.join(os.path.dirname(__file__), 'model'))
from model.multi_task_graph_router import graph_router_prediction
parser = argparse.ArgumentParser()
parser.add_argument("--config_file", type=str, default="configs/config.yaml")
parser.add_argument("--query", type=str,
                    default="What is the derivative of f(x) = x^3 + 2x^2 - x + 5?",
                    help="Input query to process")
args = parser.parse_args()
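# Example invocation (script filename is whatever this file is saved as):
#   python <this_file>.py --query "What is the integral of sin(x) dx?"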
# Initialize an OpenAI-compatible client pointed at the NVIDIA API.
# The key is read from the environment instead of being hardcoded;
# export NVIDIA_API_KEY (variable name chosen here) before running.
client = OpenAI(
    base_url="https://integrate.api.nvidia.com/v1",
    api_key=os.environ.get("NVIDIA_API_KEY"),
    timeout=60,
    max_retries=2
)
def model_prompting(
    llm_model: str,
    prompt: str,
    max_token_num: int = 1024,
    temperature: float = 0.2,
    top_p: float = 0.7,
    stream: bool = True,
) -> str:
    """
    Get a response from an LLM model using the OpenAI-compatible NVIDIA API.

    Args:
        llm_model: Name of the model to use (e.g., "meta/llama-3.1-8b-instruct")
        prompt: Input prompt text
        max_token_num: Maximum number of tokens to generate
        temperature: Sampling temperature
        top_p: Top-p sampling parameter
        stream: Whether to stream the response

    Returns:
        Generated text response
    """
    try:
        completion = client.chat.completions.create(
            model=llm_model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_token_num,
            temperature=temperature,
            top_p=top_p,
            stream=stream
        )
        if not stream:
            # A non-streaming call returns a single completion object.
            return completion.choices[0].message.content
        # A streaming call yields delta chunks that must be concatenated.
        response_text = ""
        for chunk in completion:
            if chunk.choices[0].delta.content is not None:
                response_text += chunk.choices[0].delta.content
        return response_text
    except Exception as e:
        raise Exception(f"API call failed: {str(e)}") from e
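# Example usage (assumes a valid NVIDIA_API_KEY and access to this model):
#   reply = model_prompting("meta/llama-3.1-8b-instruct", "What is 2 + 2?")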
def generate_task_description(query: str) -> str:
    """
    Generate a concise task description using the LLM API.

    Args:
        query: The user's input query

    Returns:
        A concise task description
    """
    prompt = f"""Analyze the following query and provide a concise task description that identifies the type of task and domain it belongs to. Focus on the core problem type and relevant domain areas.

Query: {query}

Please provide a brief, focused task description that captures:
1. The primary task type (e.g., mathematical calculation, text analysis, coding, reasoning, etc.)
2. The relevant domain or subject area
3. The complexity level or approach needed

Keep the description concise and informative. Respond with just the task description, no additional formatting."""
    try:
        task_description = model_prompting(
            llm_model="meta/llama-3.1-8b-instruct",
            prompt=prompt,
            max_token_num=256,
            temperature=0.1,
            top_p=0.9,
            stream=True
        )
        return task_description.strip()
    except Exception as e:
        print(f"Warning: Failed to generate task description via API: {str(e)}")
        # Fall back to a generic description.
        return "General query processing task requiring analysis and response generation."
def get_cls_embedding(text, model_name="allenai/longformer-base-4096", device="cpu"):
    """
    Extracts the [CLS] embedding from a given text using Longformer.

    Args:
        text (str): Input text
        model_name (str): Hugging Face model name
        device (str): "cpu" or "cuda"

    Returns:
        torch.Tensor: CLS embedding of shape (1, hidden_size)
    """
    # Note: the tokenizer and model are re-loaded on every call;
    # cache them at module level if this is called in a loop.
    tokenizer = LongformerTokenizer.from_pretrained(model_name)
    model = LongformerModel.from_pretrained(model_name).to(device)
    model.eval()

    # Tokenize input (Longformer supports sequences up to 4096 tokens).
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=4096).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    cls_embedding = outputs.last_hidden_state[:, 0, :]  # (1, hidden_size)
    return cls_embedding
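# Example (hidden size is 768 for longformer-base-4096):
#   emb = get_cls_embedding("hello world")  # -> tensor of shape (1, 768)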
def main():
    """Main execution function with error handling."""
    try:
        with open(args.config_file, 'r', encoding='utf-8') as file:
            config = yaml.safe_load(file)

        print("Loading training data...")
        train_df = pd.read_csv(config['train_data_path'])
        train_df = train_df[train_df["task_name"] != 'quac']
        print(f"Loaded {len(train_df)} training samples")

        input_user_query = args.query
        print(f"Input Query: {input_user_query}")

        # Generate an embedding for the query
        print("Generating query embedding...")
        user_query_embedding = get_cls_embedding(input_user_query).squeeze(0)

        # Call the LLM to generate the user task description
        print("Generating task description using LLM API...")
        user_task_description = generate_task_description(input_user_query)
        print(f"Generated Task Description: {user_task_description}")

        # Generate an embedding for the task description
        print("Generating task description embedding...")
        user_task_embedding = get_cls_embedding(user_task_description).squeeze(0)

        # Prepare the test dataframe: one row per candidate LLM
        test_df = train_df.head(config['llm_num']).copy()
        test_df['query'] = input_user_query
        test_df['task_description'] = user_task_description
        # Store the stringified embeddings in the first row. The task embedding
        # gets its own column rather than overwriting the task description text
        # set above ('task_description_embedding' is the assumed column name).
        test_df.loc[0, 'query_embedding'] = str(user_query_embedding)
        test_df.loc[0, 'task_description_embedding'] = str(user_task_embedding)
| print("Running graph router prediction...") | |
| graph_router_prediction(router_data_train=train_df, router_data_test=test_df, llm_path=config['llm_description_path'], | |
| llm_embedding_path=config['llm_embedding_path'], config=config) | |
| print("Pipeline completed successfully!") | |
| except FileNotFoundError as e: | |
| print(f"Error: Configuration file not found - {e}") | |
| sys.exit(1) | |
| except Exception as e: | |
| print(f"Error during execution: {e}") | |
| sys.exit(1) | |
| if __name__ == "__main__": | |
| main() | |