---
tags:
- dnabert
- bacteria
- kmer
- translation-initiation-site
- sequence-modeling
- DNA
library_name: transformers
---

# BacteriaTIS-DNABERT-K6-89M

This model, `BacteriaTIS-DNABERT-K6-89M`, is a **DNA sequence classifier** based on **DNABERT** and fine-tuned for **Translation Initiation Site (TIS) classification** in bacterial genomes. It operates on **6-mer tokenized sequences** derived from a **60 bp window (30 bp upstream + 30 bp downstream)** around the TIS. The model was fine-tuned with **89M trainable parameters**.

## Model Details

- **Base Model:** DNABERT
- **Task:** Translation Initiation Site (TIS) Classification
- **K-mer Size:** 6
- **Input Sequence Window:** 60 bp (30 bp upstream + 30 bp downstream) around the TIS in the ORF sequence
- **Number of Trainable Parameters:** 89M
- **Max Sequence Length:** 512 tokens
- **Precision Used:** AMP (Automatic Mixed Precision)

---

### **Install Dependencies**

Ensure you have `transformers` and `torch` installed:

```bash
pip install torch transformers
```

### **Load Model & Tokenizer**

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load the fine-tuned classifier and its k-mer tokenizer from the Hugging Face Hub
model_checkpoint = "Genereux-akotenou/BacteriaTIS-DNABERT-K6-89M"
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
```

### **Inference Example**

To classify a candidate TIS, extract the 60 bp window (30 bp upstream + 30 bp downstream) around the start codon and convert it to overlapping 6-mers:

```python
def generate_kmer(sequence: str, k: int, overlap: int = 1) -> str:
    """Split a DNA sequence into space-separated k-mers.

    `overlap` is the step between consecutive k-mer start positions;
    the default of 1 produces fully overlapping k-mers.
    """
    return " ".join([sequence[j:j + k] for j in range(0, len(sequence) - k + 1, overlap)])

# Example TIS-centered sequence (60 bp window)
sequence = "AGAACCAGCCGGAGACCTCCTGCTCGTACATGAAAGGCTCGAGCAGCCGGGCGAGGGCGG"
seq_kmer = generate_kmer(sequence, k=6)  # 60 - 6 + 1 = 55 overlapping 6-mers
```

### **Run Model**

```python
# Tokenize the space-separated 6-mers
inputs = tokenizer(
    seq_kmer,
    return_tensors="pt",
    max_length=tokenizer.model_max_length,
    padding="max_length",
    truncation=True
)

# Run inference without tracking gradients
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
    # Index of the highest-scoring class; the label mapping is stored in model.config.id2label
    predicted_class = torch.argmax(logits, dim=-1).item()
```
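The inference example starts from a window that has already been extracted. A minimal sketch of building that window from a longer sequence is shown below. `extract_tis_window` is a hypothetical helper, not part of this model's release, and it assumes `tis` is the 0-based index of the first base of the start codon on the forward strand, so the codon begins at index 30 of the window. That offset is an assumption: the example window above contains an ATG starting at 0-based index 29, so check the convention used to build the training data before relying on it. The block repeats `generate_kmer` so it runs on its own.

```python
# Hypothetical helper: not part of the released model; the window convention is an assumption.
VALID_BASES = frozenset("ACGT")


def extract_tis_window(genome: str, tis: int, upstream: int = 30, downstream: int = 30) -> str:
    """Return genome[tis - upstream : tis + downstream] in upper case.

    Assumes `tis` is the 0-based index of the first base of the start codon on
    the forward strand, so the codon begins at index `upstream` of the window.
    """
    start, end = tis - upstream, tis + downstream
    if start < 0 or end > len(genome):
        raise ValueError(f"window [{start}, {end}) falls outside the sequence (length {len(genome)})")
    window = genome[start:end].upper()
    # Reject ambiguous bases (e.g. N) rather than passing unknown k-mers to the tokenizer
    if not set(window) <= VALID_BASES:
        raise ValueError("window contains characters other than A, C, G, T")
    return window


def generate_kmer(sequence: str, k: int, overlap: int = 1) -> str:
    """Same k-mer encoding as above: space-separated k-mers with step `overlap`."""
    return " ".join(sequence[j:j + k] for j in range(0, len(sequence) - k + 1, overlap))


if __name__ == "__main__":
    # Toy sequence with an ATG at index 40 (illustrative data, not a real genome)
    genome = "G" * 40 + "ATG" + "T" * 40
    window = extract_tis_window(genome, tis=40)
    seq_kmer = generate_kmer(window, k=6)
    print(len(window), len(seq_kmer.split()))  # 60 55
```

The resulting `seq_kmer` can be passed to the tokenizer exactly as in the Run Model step. A valid 60 bp window always yields 60 - 6 + 1 = 55 six-mers, which together with the tokenizer's special tokens stays well under the 512-token limit. For genes on the reverse strand, reverse-complement the sequence (and convert the coordinate) before extracting the window; this sketch does not handle that case.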