import os
import collections
import string
import re

import numpy as np
import evaluate  # replaces datasets.load_metric, which is deprecated/removed in recent datasets releases
from datasets import load_dataset
from transformers import (
    DebertaTokenizerFast,
    DebertaForQuestionAnswering,
    Trainer,
    TrainingArguments,
    default_data_collator,
)
from peft import LoraConfig, get_peft_model, TaskType
from huggingface_hub import login
# Load your HF token securely from an environment variable
hf_token = os.environ.get("roberta_token")
if hf_token:
    login(token=hf_token)
else:
    print("Warning: HF token not found in environment variable 'roberta_token'. Push to hub may fail.")
# SQuAD-style exact-match/F1 metric (CUAD is distributed in SQuAD format)
metric = evaluate.load("squad")
def normalize_answer(s):
    """Lower text and remove punctuation, articles, and extra whitespace (SQuAD-style)."""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
def prepare_train_features(examples, tokenizer, max_length=512, doc_stride=128):
    """Tokenize question/context pairs and map character-level answer spans to token positions."""
    tokenized_examples = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",        # only truncate the context, never the question
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,  # long contexts are split into several overlapping features
        return_offsets_mapping=True,
        padding="max_length",
    )
    # Each feature points back to the example it came from.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    offset_mapping = tokenized_examples.pop("offset_mapping")

    start_positions = []
    end_positions = []
    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]

        if len(answers["answer_start"]) == 0:
            # Unanswerable example: point both positions at the CLS token.
            start_positions.append(cls_index)
            end_positions.append(cls_index)
        else:
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])
            sequence_ids = tokenized_examples.sequence_ids(i)

            # Find the first and last tokens that belong to the context (sequence id 1).
            token_start_index = 0
            while sequence_ids[token_start_index] != 1:
                token_start_index += 1
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != 1:
                token_end_index -= 1

            # If the answer is not fully inside this feature, label it with the CLS index.
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                start_positions.append(cls_index)
                end_positions.append(cls_index)
            else:
                # Otherwise move the indices onto the answer boundaries.
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                start_positions.append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                end_positions.append(token_end_index + 1)

    tokenized_examples["start_positions"] = start_positions
    tokenized_examples["end_positions"] = end_positions
    return tokenized_examples
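# The post-processing below needs, for every validation feature, the id of the example it
# came from ("example_id") and an "offset_mapping" restricted to context tokens. Those are
# not produced by prepare_train_features, so a separate validation preprocessing step is
# sketched here, following the standard Hugging Face question-answering example; adjust it
# if your evaluation pipeline differs.
def prepare_validation_features(examples, tokenizer, max_length=512, doc_stride=128):
    tokenized_examples = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

    tokenized_examples["example_id"] = []
    for i in range(len(tokenized_examples["input_ids"])):
        sequence_ids = tokenized_examples.sequence_ids(i)
        sample_index = sample_mapping[i]
        tokenized_examples["example_id"].append(examples["id"][sample_index])
        # Keep offsets only for context tokens so post-processing can ignore the question.
        tokenized_examples["offset_mapping"][i] = [
            (o if sequence_ids[k] == 1 else None)
            for k, o in enumerate(tokenized_examples["offset_mapping"][i])
        ]
    return tokenized_examples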
def postprocess_qa_predictions(examples, features, raw_predictions, tokenizer, n_best_size=20, max_answer_length=30):
    """Map start/end logits over features back to answer strings, one per original example."""
    all_start_logits, all_end_logits = raw_predictions

    # Group the features by the example they were generated from.
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    predictions = collections.OrderedDict()
    for example_index, example in enumerate(examples):
        feature_indices = features_per_example[example_index]
        min_null_score = None
        valid_answers = []
        context = example["context"]

        for feature_index in feature_indices:
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            # Score of the "no answer" prediction (CLS token).
            cls_index = features[feature_index]["input_ids"].index(tokenizer.cls_token_id)
            feature_null_score = start_logits[cls_index] + end_logits[cls_index]
            if min_null_score is None or min_null_score > feature_null_score:
                min_null_score = feature_null_score

            # Consider only the n_best start/end indices.
            start_indexes = np.argsort(start_logits)[-1: -n_best_size - 1: -1].tolist()
            end_indexes = np.argsort(end_logits)[-1: -n_best_size - 1: -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip indices that fall outside the context or have no offset.
                    if (
                        start_index >= len(offsets)
                        or end_index >= len(offsets)
                        or offsets[start_index] is None
                        or offsets[end_index] is None
                    ):
                        continue
                    # Skip inverted or overly long spans.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue
                    start_char = offsets[start_index][0]
                    end_char = offsets[end_index][1]
                    valid_answers.append(
                        {"score": start_logits[start_index] + end_logits[end_index], "text": context[start_char:end_char]}
                    )

        best_answer = max(valid_answers, key=lambda x: x["score"]) if valid_answers else {"text": "", "score": 0.0}
        predictions[example["id"]] = best_answer["text"]
    return predictions
def compute_metrics(p, tokenizer, examples, features):
    predictions = postprocess_qa_predictions(examples, features, p.predictions, tokenizer)
    formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
    references = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
    return metric.compute(predictions=formatted_predictions, references=references)
def main():
    model_name = "microsoft/deberta-xlarge"
    output_dir = "./deberta-lora-cuad-finetuned"

    datasets = load_dataset("theatticusproject/cuad-qa")
    tokenizer = DebertaTokenizerFast.from_pretrained(model_name)
    model = DebertaForQuestionAnswering.from_pretrained(model_name)

    # LoRA config: tune rank and dropout as needed. DeBERTa (v1) checkpoints expose a single
    # combined attention projection named "in_proj"; for DeBERTa-v2/v3 checkpoints use
    # ["query_proj", "value_proj"] instead.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules=["in_proj"],
        lora_dropout=0.1,
        bias="none",
        task_type=TaskType.QUESTION_ANS,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    train_dataset = datasets["train"].map(
        lambda examples: prepare_train_features(examples, tokenizer),
        batched=True,
        remove_columns=datasets["train"].column_names,
    )
    # Validation features keep example_id/offset_mapping so predictions can be post-processed.
    val_dataset = datasets["validation"].map(
        lambda examples: prepare_validation_features(examples, tokenizer),
        batched=True,
        remove_columns=datasets["validation"].column_names,
    )
    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="steps",
        eval_steps=500,
        save_steps=500,
        save_total_limit=2,
        learning_rate=3e-4,  # LoRA usually tolerates a higher learning rate than full fine-tuning
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        num_train_epochs=3,
        weight_decay=0.0,
        logging_dir=f"{output_dir}/logs",
        logging_steps=100,
        load_best_model_at_end=True,
        metric_for_best_model="eval_f1",
        greater_is_better=True,
        fp16=True,
        push_to_hub=True,
        hub_model_id="AvocadoMuffin/deberta_finetuned_qa_lora",
        hub_strategy="checkpoint",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
        compute_metrics=lambda p: compute_metrics(p, tokenizer, datasets["validation"], val_dataset),
    )

    trainer.train()
    trainer.push_to_hub()


if __name__ == "__main__":
    main()
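# Usage sketch (not executed): loading the pushed LoRA adapter for inference.
# This assumes the adapter ends up in the "AvocadoMuffin/deberta_finetuned_qa_lora" repo
# configured above and that the same base checkpoint is used; adjust both if they differ.
#
#   from transformers import DebertaForQuestionAnswering, DebertaTokenizerFast
#   from peft import PeftModel
#
#   base = DebertaForQuestionAnswering.from_pretrained("microsoft/deberta-xlarge")
#   model = PeftModel.from_pretrained(base, "AvocadoMuffin/deberta_finetuned_qa_lora")
#   model = model.merge_and_unload()  # optionally fold the LoRA weights into the base model
#   tokenizer = DebertaTokenizerFast.from_pretrained("microsoft/deberta-xlarge")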