Update CXR_LLAVA_HF.py
Browse files- CXR_LLAVA_HF.py +1 -1
CXR_LLAVA_HF.py
CHANGED
@@ -616,7 +616,7 @@ class CXRLLAVAModel(PreTrainedModel):
|
|
616 |
image_args = {"images": images}
|
617 |
do_sample = True if temperature > 0.001 else False
|
618 |
num_image_tokens = 1
|
619 |
-
max_context_length = getattr(self.config, 'max_position_embeddings', 2048
|
620 |
|
621 |
max_new_tokens = min(512, max_context_length - input_ids.shape[-1] - num_image_tokens)
|
622 |
thread = Thread(target=self.generate, kwargs=dict(
|
|
|
616 |
image_args = {"images": images}
|
617 |
do_sample = True if temperature > 0.001 else False
|
618 |
num_image_tokens = 1
|
619 |
+
max_context_length = getattr(self.config, 'max_position_embeddings', 1024) # 2048->1024
|
620 |
|
621 |
max_new_tokens = min(512, max_context_length - input_ids.shape[-1] - num_image_tokens)
|
622 |
thread = Thread(target=self.generate, kwargs=dict(
|