|
from transformers import pipeline
from datasets import load_dataset

# Load the text-generation pipeline once at module level so every call
# reuses the same model instance.
model = pipeline("text-generation", model="meta-llama/Llama-3.2-1B")
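
# Optional, hedged variant for GPU users (assumes torch with CUDA support is
# installed; device and torch_dtype are standard pipeline arguments):
# model = pipeline(
#     "text-generation",
#     model="meta-llama/Llama-3.2-1B",
#     device=0,                     # first CUDA device
#     torch_dtype=torch.bfloat16,   # requires `import torch`
# )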
|
|
|
def load_data():
    """
    Load the arxiver dataset from the Hugging Face Hub.

    Returns:
        Dataset object containing the loaded data, or None if loading fails.
    """
    try:
        ds = load_dataset("neuralwork/arxiver")
        return ds
    except Exception as e:
        print(f"An error occurred while loading the dataset: {e}")
        return None
|
|
|
def generate_text(prompt, max_length=50, num_return_sequences=1, temperature=1.0):
    """
    Generate text using the Llama 3.2 model.

    Parameters:
        prompt (str): The input prompt for text generation.
        max_length (int): Maximum total length in tokens (prompt plus generation).
        num_return_sequences (int): The number of sequences to return.
        temperature (float): Controls the randomness of predictions. Lower values
            make the output more deterministic.

    Returns:
        List of generated text sequences (empty if generation fails).
    """
    try:
        # do_sample=True is required for temperature to have any effect, and
        # for num_return_sequences > 1 (greedy decoding returns one sequence).
        output = model(
            prompt,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            temperature=temperature,
            do_sample=True,
        )
        return [o['generated_text'] for o in output]
    except Exception as e:
        print(f"An error occurred: {e}")
        return []
|
|
|
|
|
if __name__ == "__main__":
    dataset = load_data()
    if dataset:
        print("Dataset loaded successfully.")
        # Peek at the first training record.
        print(dataset['train'][0])
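
        # A minimal sketch of prompting the model from the dataset itself
        # rather than a hardcoded string. This assumes each arxiver record
        # exposes a 'title' field; treat that field name as an assumption
        # about the dataset schema, not something this script verifies.
        title = dataset['train'][0].get('title')
        if title:
            for summary in generate_text(f"Summarize the paper titled: {title}", max_length=100, temperature=0.7):
                print(f"Dataset-driven generation:\n{summary}\n")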
|
|
|
    prompt = "Describe the process of synaptic transmission in the brain."
    generated_texts = generate_text(prompt, max_length=100, num_return_sequences=3, temperature=0.7)
    for i, text in enumerate(generated_texts):
        print(f"Generated Text {i+1}:\n{text}\n")