boltuix commited on
Commit
7dff5aa
·
verified ·
1 Parent(s): 888ffec

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +8 -3
README.md CHANGED
@@ -348,10 +348,15 @@ To adapt `bert-mini` for custom tasks (e.g., specific IoT commands):
348
 
349
  # Tokenize dataset
350
  def tokenize_function(examples):
351
- return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=64)
 
352
 
 
353
  tokenized_dataset = dataset.map(tokenize_function, batched=True)
354
- tokenized_dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])
 
 
 
355
 
356
  # Define training arguments
357
  training_args = TrainingArguments(
@@ -388,7 +393,7 @@ To adapt `bert-mini` for custom tasks (e.g., specific IoT commands):
388
  outputs = model(**inputs)
389
  logits = outputs.logits
390
  predicted_class = torch.argmax(logits, dim=1).item()
391
- print(f"Predicted class for '{text}': {'Valid IoT Command' if predicted_class == 1 else 'Invalid Command'}")
392
  ```
393
  3. **Deploy**: Export to ONNX or TensorFlow Lite for edge devices.
394
 
 
348
 
349
  # Tokenize dataset
350
  def tokenize_function(examples):
351
+ # Use return_tensors="pt" here to get PyTorch tensors directly
352
+ return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=64, return_tensors="pt")
353
 
354
+ # Pass batched=True to the map function as the tokenize_function is designed to handle batches
355
  tokenized_dataset = dataset.map(tokenize_function, batched=True)
356
+ # We don't need to set the format to "torch" explicitly here anymore
357
+ # because the tokenizer is already returning PyTorch tensors.
358
+ # tokenized_dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])
359
+
360
 
361
  # Define training arguments
362
  training_args = TrainingArguments(
 
393
  outputs = model(**inputs)
394
  logits = outputs.logits
395
  predicted_class = torch.argmax(logits, dim=1).item()
396
+ print(f"Predicted class for '{text}': {'Valid IoT Command' if predicted_class == 1 else 'Invalid Command'}")
397
  ```
398
  3. **Deploy**: Export to ONNX or TensorFlow Lite for edge devices.
399