File size: 1,910 Bytes
0e26089
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from transformers import pipeline
from datasets import load_dataset

# Initialize the pipeline with the Llama 3.2 model.
# NOTE: this runs at import time and downloads/loads the model weights as a
# side effect. meta-llama/Llama-3.2-1B is a gated repo on the Hugging Face
# Hub — presumably this requires prior `huggingface-cli login` / an HF token
# with access granted; verify in the deployment environment.
model = pipeline("text-generation", model="meta-llama/Llama-3.2-1B")

def load_data():
    """
    Fetch the arxiver dataset from the Hugging Face Hub.

    Returns:
    - Dataset object on success, or None if loading failed (the error is
      printed rather than raised, so callers must check for None).
    """
    try:
        return load_dataset("neuralwork/arxiver")
    except Exception as e:
        # Best-effort: report the problem and signal failure via None.
        print(f"An error occurred while loading the dataset: {e}")
        return None

def generate_text(prompt, max_length=50, num_return_sequences=1, temperature=1.0):
    """
    Generate text using the Llama 3.2 model.

    Parameters:
    - prompt (str): The input prompt for text generation.
    - max_length (int): The maximum total length (prompt + generated tokens).
    - num_return_sequences (int): The number of sequences to return.
    - temperature (float): Controls the randomness of predictions. Lower values make the output more deterministic.

    Returns:
    - List of generated text sequences (empty list on failure).
    """
    try:
        # do_sample=True is required here: under the default greedy decoding
        # the pipeline ignores `temperature` entirely, and num_return_sequences > 1
        # raises an error because greedy search can only produce one sequence.
        output = model(
            prompt,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            temperature=temperature,
            do_sample=True,
        )
        return [o['generated_text'] for o in output]
    except Exception as e:
        # Best-effort: report the problem and return an empty result so the
        # caller's loop over generated texts degrades gracefully.
        print(f"An error occurred: {e}")
        return []

# Example usage
if __name__ == "__main__":
    # Load the dataset and show the first training example when available.
    dataset = load_data()
    if dataset:
        print("Dataset loaded successfully.")
        # Splits are accessed by name, e.g. dataset['train'].
        print(dataset['train'][0])

    prompt = "Describe the process of synaptic transmission in the brain."
    results = generate_text(
        prompt,
        max_length=100,
        num_return_sequences=3,
        temperature=0.7,
    )
    for i, text in enumerate(results):
        print(f"Generated Text {i+1}:\n{text}\n")