| """ | |
| LLM Interface Module for Cross-Domain Uncertainty Quantification | |
| This module provides a unified interface for interacting with large language models, | |
| supporting multiple model architectures and uncertainty quantification methods. | |
| """ | |
| import torch | |
| import numpy as np | |
| from typing import List, Dict, Any, Union, Optional | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM | |
| from tqdm import tqdm | |
class LLMInterface:
    """Interface for interacting with large language models with uncertainty quantification."""

    def __init__(
        self,
        model_name: str,
        model_type: str = "causal",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        cache_dir: Optional[str] = None,
        max_length: int = 512,
        temperature: float = 1.0,
        top_p: float = 1.0,
        num_beams: int = 1
    ):
        """
        Initialize the LLM interface.

        Args:
            model_name: Name of the Hugging Face model to use
            model_type: Type of model ('causal' or 'seq2seq')
            device: Device to run the model on ('cpu' or 'cuda')
            cache_dir: Directory to cache models
            max_length: Maximum length of generated sequences
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            num_beams: Number of beams for beam search
        """
        self.model_name = model_name
        self.model_type = model_type
        self.device = device
        self.cache_dir = cache_dir
        self.max_length = max_length
        self.temperature = temperature
        self.top_p = top_p
        self.num_beams = num_beams

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir=cache_dir
        )

        # Load model based on type
        if model_type == "causal":
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32
            ).to(device)
        elif model_type == "seq2seq":
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32
            ).to(device)
        else:
            raise ValueError(f"Unsupported model type: {model_type}")

        # Response cache for efficiency
        self.response_cache = {}
    def generate(
        self,
        prompt: str,
        num_samples: int = 1,
        return_logits: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Generate a response from the model, optionally returning multiple
        samples and token logits for downstream uncertainty quantification.

        Args:
            prompt: Input text prompt
            num_samples: Number of samples to generate (for MC methods)
            return_logits: Whether to return token logits
            **kwargs: Additional generation parameters

        Returns:
            Dictionary containing:
                - response: The generated text
                - samples: Multiple samples if num_samples > 1
                - logits: Token logits if return_logits is True
        """
        # Check cache first
        cache_key = (prompt, num_samples, return_logits, str(kwargs))
        if cache_key in self.response_cache:
            return self.response_cache[cache_key]

        # Prepare inputs
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        # Set generation parameters
        gen_kwargs = {
            "max_length": self.max_length,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "num_beams": self.num_beams,
            "do_sample": self.temperature > 0,
            "pad_token_id": self.tokenizer.eos_token_id
        }
        gen_kwargs.update(kwargs)

        # Generate multiple samples if requested (e.g. for MC-style estimates)
        samples = []
        all_logits = []
        for _ in range(num_samples):
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    output_scores=return_logits,
                    return_dict_in_generate=True,
                    **gen_kwargs
                )

            # Extract generated tokens (causal models echo the prompt, so strip it)
            if self.model_type == "causal":
                gen_tokens = outputs.sequences[0, inputs.input_ids.shape[1]:]
            else:
                gen_tokens = outputs.sequences[0]

            # Decode tokens to text
            gen_text = self.tokenizer.decode(gen_tokens, skip_special_tokens=True)
            samples.append(gen_text)

            # Extract logits if requested
            if return_logits and hasattr(outputs, "scores"):
                all_logits.append([score.cpu().numpy() for score in outputs.scores])

        # Prepare result
        result = {
            "response": samples[0],  # Primary response is first sample
            "samples": samples
        }
        if return_logits:
            result["logits"] = all_logits

        # Cache result
        self.response_cache[cache_key] = result
        return result
    def batch_generate(
        self,
        prompts: List[str],
        **kwargs
    ) -> List[Dict[str, Any]]:
        """
        Generate responses for a batch of prompts.

        Args:
            prompts: List of input text prompts
            **kwargs: Additional generation parameters

        Returns:
            List of generation results for each prompt
        """
        results = []
        for prompt in tqdm(prompts, desc="Generating responses"):
            results.append(self.generate(prompt, **kwargs))
        return results
    def clear_cache(self):
        """Clear the response cache."""
        self.response_cache = {}
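

# Minimal usage sketch, kept as a __main__ guard so importing the module has no
# side effects. It assumes a small causal checkpoint such as "gpt2" is available;
# the mean token-entropy score computed below is only one illustrative way to turn
# the returned per-step logits into an uncertainty proxy, not a method prescribed
# by this module.
if __name__ == "__main__":
    llm = LLMInterface(model_name="gpt2", model_type="causal", max_length=64)

    # Draw several stochastic samples and request per-step logits.
    output = llm.generate(
        "The capital of France is",
        num_samples=3,
        return_logits=True,
        temperature=0.7,
    )
    print("Primary response:", output["response"])
    print("All samples:", output["samples"])

    # Illustrative uncertainty proxy: mean predictive entropy over the generated
    # tokens of the first sample. Each entry in output["logits"][0] is a
    # [1, vocab_size] array of scores for one generation step.
    step_scores = output["logits"][0]
    entropies = []
    for score in step_scores:
        logits = score[0].astype(np.float64)
        logits = logits - logits.max()  # stabilise the softmax
        probs = np.exp(logits) / np.exp(logits).sum()
        entropies.append(-(probs * np.log(probs + 1e-12)).sum())
    print("Mean token entropy:", float(np.mean(entropies)))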