Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,7 +14,8 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
| 14 |
image_size = 384
|
| 15 |
|
| 16 |
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 17 |
-
std=[0.229, 0.224, 0.225
|
|
|
|
| 18 |
transform = transforms.Compose([transforms.Resize((image_size, image_size)),transforms.ToTensor(),normalize])
|
| 19 |
|
| 20 |
#######Tag2Text Model
|
|
@@ -41,7 +42,8 @@ def inference(raw_image, model_n , input_tag):
|
|
| 41 |
image = transform(raw_image).unsqueeze(0).to(device)
|
| 42 |
if model_n == 'Recognize Anything Model':
|
| 43 |
model = model_ram
|
| 44 |
-
|
|
|
|
| 45 |
return tags[0],tags_chinese[0], 'none'
|
| 46 |
else:
|
| 47 |
model = model_tag2text
|
|
|
|
| 14 |
image_size = 384
|
| 15 |
|
| 16 |
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 17 |
+
std=[0.229, 0.224, 0.225
|
| 18 |
+
])
|
| 19 |
transform = transforms.Compose([transforms.Resize((image_size, image_size)),transforms.ToTensor(),normalize])
|
| 20 |
|
| 21 |
#######Tag2Text Model
|
|
|
|
| 42 |
image = transform(raw_image).unsqueeze(0).to(device)
|
| 43 |
if model_n == 'Recognize Anything Model':
|
| 44 |
model = model_ram
|
| 45 |
+
with torch.no_grad():
|
| 46 |
+
tags, tags_chinese = model.generate_tag(image)
|
| 47 |
return tags[0],tags_chinese[0], 'none'
|
| 48 |
else:
|
| 49 |
model = model_tag2text
|