Update modeling_llama.py
Browse files- modeling_llama.py +2 -4
modeling_llama.py
CHANGED
@@ -1114,12 +1114,10 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|
1114 |
def get_decoder(self):
|
1115 |
return self.model
|
1116 |
|
1117 |
-
|
1118 |
-
if torch.any(input_ids == self.shutdown_token_id):
|
1119 |
-
return True
|
1120 |
def detect_shutdown_token(self, input_ids):
|
1121 |
shutdown_token_tensor = torch.tensor(self.shutdown_token_id, device=input_ids.device, dtype=input_ids.dtype)
|
1122 |
-
|
1123 |
return True
|
1124 |
return False
|
1125 |
|
|
|
1114 |
def get_decoder(self):
|
1115 |
return self.model
|
1116 |
|
1117 |
+
|
|
|
|
|
1118 |
def detect_shutdown_token(self, input_ids):
|
1119 |
shutdown_token_tensor = torch.tensor(self.shutdown_token_id, device=input_ids.device, dtype=input_ids.dtype)
|
1120 |
+
if torch.any(input_ids == shutdown_token_tensor):
|
1121 |
return True
|
1122 |
return False
|
1123 |
|