|
# Model Card: BART-Based Content Generation Model |
|
|
|
## Model Overview |
|
|
|
This model is a fine-tuned version of `facebook/bart-base` for content generation tasks. Fine-tuned on the cnn_dailymail dataset, it aims to produce high-quality text while retaining the efficiency of the compact BART base architecture.
|
|
|
## Model Details |
|
|
|
- **Model Architecture:** BART |
|
- **Base Model:** `facebook/bart-base` |
|
- **Task:** Content Generation |
|
- **Dataset:** cnn_dailymail |
|
- **Framework:** Hugging Face Transformers |
|
- **Training Hardware:** NVIDIA GPU (CUDA)
|
|
|
## Installation |
|
|
|
To use the model, install the necessary dependencies: |
|
|
|
```sh
pip install transformers torch datasets evaluate
```
|
|
|
## Usage |
|
|
|
### Load the Model and Tokenizer |
|
```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# Load the fine-tuned model and tokenizer
model_path = "fine_tuned_model"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Define a test prompt
input_text = "Technology"
inputs = tokenizer(input_text, return_tensors="pt").to(device)

# Generate output
with torch.no_grad():
    output_ids = model.generate(**inputs)
output_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]

print(f"Generated Content: {output_text}")
```
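
The call above uses the model's default generation settings. Output length and quality can be tuned through standard `generate` arguments; the values below are illustrative, not settings this model was tuned with:

```python
# Illustrative generation settings (assumed values, not tuned for this model)
output_ids = model.generate(
    **inputs,
    max_new_tokens=128,       # cap on newly generated tokens
    num_beams=4,              # beam search for more coherent output
    no_repeat_ngram_size=3,   # discourage repeated phrases
    early_stopping=True,      # stop beam search once candidates are complete
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
```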
|
|
|
## Training Details |
|
|
|
### Data Preprocessing |
|
The dataset was split into: |
|
- **Train:** 80% |
|
- **Validation:** 10% |
|
- **Test:** 10% |
|
|
|
Tokenization was applied using the `facebook/bart-base` tokenizer with truncation and padding. |
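
As a minimal sketch, the split and tokenization could be reproduced along these lines. The `"3.0.0"` config name, the `article`/`highlights` column names, the maximum lengths, and the seed are assumptions for illustration, not confirmed training settings:

```python
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

# Load CNN/DailyMail (the "3.0.0" config is an assumption)
dataset = load_dataset("cnn_dailymail", "3.0.0", split="train")

# 80/10/10 split: carve off 20%, then halve it into validation and test
split = dataset.train_test_split(test_size=0.2, seed=42)
holdout = split["test"].train_test_split(test_size=0.5, seed=42)
splits = {
    "train": split["train"],
    "validation": holdout["train"],
    "test": holdout["test"],
}

def preprocess(batch):
    # "article" is the source text, "highlights" the target summary
    model_inputs = tokenizer(
        batch["article"], truncation=True, padding="max_length", max_length=512
    )
    labels = tokenizer(
        batch["highlights"], truncation=True, padding="max_length", max_length=128
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized = {name: ds.map(preprocess, batched=True) for name, ds in splits.items()}
```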
|
|
|
### Fine-Tuning |
|
- **Epochs:** 3 |
|
- **Batch Size:** 4 |
|
- **Learning Rate:** 2e-5 |
|
- **Weight Decay:** 0.01 |
|
- **Evaluation Strategy:** per epoch (see the training sketch below)
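
A minimal sketch of how these hyperparameters map onto `Seq2SeqTrainingArguments` and `Seq2SeqTrainer`; the output directory and the `tokenized` splits (from the preprocessing sketch above) are assumptions:

```python
from transformers import (
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")

training_args = Seq2SeqTrainingArguments(
    output_dir="fine_tuned_model",   # assumed save path
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    learning_rate=2e-5,
    weight_decay=0.01,
    evaluation_strategy="epoch",     # named `eval_strategy` in newer transformers releases
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"],       # from the preprocessing sketch above
    eval_dataset=tokenized["validation"],
)
trainer.train()
```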
|
|
|
## Evaluation Metrics |
|
The model was evaluated using the ROUGE metric: |
|
```python
import evaluate

rouge = evaluate.load("rouge")

# Example evaluation
references = ["The generated story was highly creative and engaging."]
predictions = ["The output was imaginative and captivating."]
results = rouge.compute(predictions=predictions, references=references)
print("Evaluation Metrics (ROUGE):", results)
```
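
With the `evaluate` implementation of ROUGE, `results` is a dictionary of scores keyed `rouge1`, `rouge2`, `rougeL`, and `rougeLsum`, each in [0, 1]; higher values indicate greater n-gram overlap with the references.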
|
|
|
## Performance |
|
- **ROUGE Score:** Achieves competitive ROUGE scores for content generation quality.

- **Inference Speed:** The compact `bart-base` backbone keeps inference relatively fast.

- **Generalization:** Handles diverse text generation tasks but may require domain-specific fine-tuning for best results.
|
|
|
## Limitations |
|
- May generate slightly verbose or overly detailed content in some cases. |
|
- Runs on CPU, but a GPU is recommended for practical inference speed.
|
|
|
## Future Improvements |
|
- Experiment with larger models such as `facebook/bart-large` for enhanced generation quality.
|
- Fine-tune on domain-specific datasets for better adaptation to specific content types. |
|
|
|
|