"""
Demo that takes an iNaturalist taxa ID as input and generates a prediction 
for each location on the globe and saves the ouput as an image.
"""

import argparse
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.offsetbox import AnnotationBbox, OffsetImage

import create_inputs_to_fs_sinr
import datasets
import eval
import utils

# Path to precomputed GPT text embeddings (only needed by the commented-out
# GritLM/text-embedding path in main()).
text_model = './experiments/gpt_data.pt'

def extract_grit_token(model, text: str):
    """Embed a text string with a GritLM model and return it as a torch tensor."""
    def gritlm_instruction(instruction):
        return "<|user|>\n" + instruction + "\n<|embed|>\n" if instruction else "<|embed|>\n"
    d_rep = model.encode([text], instruction=gritlm_instruction(""))
    d_rep = torch.from_numpy(d_rep)
    return d_rep
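
# Hypothetical usage, mirroring the commented-out GritLM lines in main():
#   from gritlm import GritLM
#   grit = GritLM("GritLM/GritLM-7B", torch_dtype="auto", mode="embedding")
#   emb = extract_grit_token(grit, "species range description")  # (1, d) tensor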

def choose_context_points_from_map(eval_params, interactive=False):
    """Collect context points by clicking on a map (interactive=True), or
    return a hard-coded set of points over the USA."""
    context_points = []

    if interactive:
        def onclick(event):
            if event.xdata is not None and event.ydata is not None:
                # Convert image coordinates to normalized geographical coordinates
                lon = event.xdata / mask.shape[1] * 2 - 1
                lat = 1 - event.ydata / mask.shape[0] * 2
                context_points.append((lon, lat))
                print(f"Added context point: ({lon}, {lat})")

        # Load ocean mask
        with open('paths.json', 'r') as f:
            paths = json.load(f)
        if eval_params['high_res']:
            mask = np.load(os.path.join(paths['masks'], 'ocean_mask_hr.npy'))
        else:
            mask = np.load(os.path.join(paths['masks'], 'ocean_mask.npy'))

        mask_inds = np.where(mask.reshape(-1) == 1)[0]

        # Reshape and create masked array for visualization
        op_im = np.ones((mask.shape[0] * mask.shape[1])) * np.nan  # Set to NaN
        op_im[mask_inds] = 0  # Placeholder for the mask visualization
        op_im = op_im.reshape((mask.shape[0], mask.shape[1]))
        op_im = np.ma.masked_invalid(op_im)

        # Set color for masked values
        cmap = plt.cm.plasma.copy()  # copy so we do not mutate the registered colormap
        cmap.set_bad(color='none')
        plt.ioff()
        # Display the map and capture context points
        fig, ax = plt.subplots(figsize=(6, 3), dpi=334)  # Define the figure size
        ax.imshow(op_im, cmap=cmap, interpolation='nearest')  # Display the image
        ax.axis('off')  # Turn off the axis

        # Connect the onclick event to the handler
        cid = fig.canvas.mpl_connect('button_press_event', onclick)

        plt.show(block=True)  # Block execution until the window is closed

        print(f"Context points collected: {context_points}")

    else:
        # Hard-coded points over the USA, stored as normalized (lon, lat).
        # To convert a (lat, lon) pair such as (37.541170, -92.003293):
        # 1) flip the order to (lon, lat), then 2) normalize by dividing
        # lon by 180 and lat by 90 (see latlon_to_normalized below).
        context_points = [(-0.5884012559178662, 0.46394662490802496),
                          (-0.5451199953511522, 0.4504212309809269),
                          (-0.5437674559584422, 0.5342786733289353),
                          (-0.589753795310576, 0.5342786733289353)]
        print(f"Context points collected: {context_points}")
    return context_points
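
# Illustrative helper (an addition, not called elsewhere): implements the
# conversion described in the comment above.
def latlon_to_normalized(lat, lon):
    """Convert a (lat, lon) pair in degrees, e.g. (37.541170, -92.003293),
    to the normalized (lon, lat) convention used for context points:
    lon / 180 and lat / 90, each in [-1, 1]."""
    return (lon / 180.0, lat / 90.0)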

def main(eval_params):
    # load params
    with open('paths.json', 'r') as f:
        paths = json.load(f)

    ckp_name = os.path.split(eval_params['model_path'])[-1]
    experiment_name = os.path.split(os.path.split(eval_params['model_path'])[-2])[-1]

    eval_overrides = {'ckp_name':ckp_name,
                      'experiment_name':experiment_name,
                      'device':eval_params['device']}

    train_overrides = {'dataset': 'eval_transformer'}
    # Optional text-embedding path (disabled): embed text with GritLM and/or
    # load the precomputed GPT embeddings from `text_model`.
    # grit = GritLM("GritLM/GritLM-7B", torch_dtype="auto", mode="embedding")
    # grit_gpt = torch.load(text_model, map_location='cpu')
    context_data = np.load('data/positive_eval_data.npz')
    text_type_value = 0

    for number_of_context_points in eval_params['context_pt_trial']:
        if eval_params['choose_context_points'] == 1:
            # Alternative: pick points interactively on a map.
            # context_points = choose_context_points_from_map(eval_params, interactive=True)
            text_emb, text_type_value = create_inputs_to_fs_sinr.use_pregenerated_textemb_fromchris(taxon_id=eval_params['test_taxa'], 
                                                                                                    text_type=eval_params['text_type'])
            context_points = create_inputs_to_fs_sinr.get_eval_context_points(taxa_id=eval_params['test_taxa'], 
                                                                              context_data=context_data, 
                                                                              size=number_of_context_points)
            model, context_locs_of_interest, train_params, class_of_interest = eval.generate_eval_embedding_from_given_points(
                                                                        context_points=context_points,
                                                                        overrides=eval_overrides,
                                                                        taxa_of_interest=eval_params['taxa_id'],
                                                                        train_overrides=train_overrides,
                                                                        text_emb=text_emb)
            # taxa_id is overwritten here so that downstream labels and save
            # names read 'selected_points' instead of a taxon ID.
            eval_params['taxa_id'] = 'selected_points'
        else:
            model, context_locs_of_interest, train_params, class_of_interest = eval.generate_eval_embeddings(
                                                                            overrides=eval_overrides,
                                                                            taxa_of_interest=eval_params['taxa_id'],
                                                                            num_context=eval_params['num_context'],
                                                                            train_overrides=train_overrides)

        # Environmental input encodings need the bioclim rasters; plain
        # coordinate encodings do not.
        if train_params['params']['input_enc'] in ['env', 'sin_cos_env']:
            raster = datasets.load_env()
        else:
            raster = None
        enc = utils.CoordEncoder(train_params['params']['input_enc'], raster=raster, input_dim=train_params['params']['input_dim'])
        enc_time = utils.CoordEncoder('sin_cos', raster=None, input_dim=2 * train_params['params']['input_time_dim'])

        # load ocean mask
        if eval_params['high_res']:
            mask = np.load(os.path.join(paths['masks'], 'ocean_mask_hr.npy'))
        else:
            mask = np.load(os.path.join(paths['masks'], 'ocean_mask.npy'))
        # mask = 0 * mask + 1  # debug: treat every cell as valid (disables the ocean mask)
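        # Keep only the cells where the mask is 1 (assumed to be land).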
        mask_inds = np.where(mask.reshape(-1) == 1)[0]
            
        # generate input features
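        # coord_grid is assumed to return one normalized (lon, lat) location per
        # map cell, matching the context-point convention used above.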
        locs = utils.coord_grid(mask.shape)
        if not eval_params['disable_ocean_mask']:
            locs = locs[mask_inds, :]
        locs = torch.from_numpy(locs)
        locs_enc = enc.encode(locs).to(eval_params['device'])
        if train_params['params']['input_time_dim'] > 0:
            # Append a constant sin/cos-encoded time feature (t = 0) plus a
            # constant 1.0 (assumed to flag the time input as present).
            extra_input = torch.cat([enc_time.encode(torch.tensor([[0.0]]), normalize=False), torch.tensor([[1.0]])],
                                    dim=1).to(eval_params['device'])
            locs_enc = torch.cat((locs_enc, extra_input.repeat(locs_enc.shape[0], 1)), dim=1)

        with torch.no_grad():
            # With eval=False the model's EMA embeddings are used instead; since
            # the EMA rate is currently 1.0, that is just the last training example seen.
            preds = model.embedding_forward(x=locs_enc, class_ids=None, return_feats=False, class_of_interest=class_of_interest, eval=True).cpu().numpy()

        # threshold predictions
        if eval_params['threshold'] > 0:
            print(f'Applying threshold of {eval_params["threshold"]} to the predictions.')
            preds[preds < eval_params['threshold']] = 0.0
            preds[preds >= eval_params['threshold']] = 1.0
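        # With e.g. --threshold 0.5, preds is now a binary presence/absence map.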
            
        # mask data
        if not eval_params['disable_ocean_mask']:
            op_im = np.ones((mask.shape[0] * mask.shape[1])) * np.nan  # set to NaN
            op_im[mask_inds] = preds
        else:
            op_im = preds

        # reshape and create masked array for visualization
        op_im = op_im.reshape((mask.shape[0], mask.shape[1]))
        op_im = np.ma.masked_invalid(op_im) 

        # set color for masked values
        cmap = plt.cm.plasma.copy()  # copy so we do not mutate the registered colormap
        cmap.set_bad(color='none')
        if eval_params['set_max_cmap_to_1']:
            vmax = 1.0
        else:
            vmax = np.max(op_im)

        if eval_params['show_map'] == 1:
            # Display the image
            fig, ax = plt.subplots(figsize=(6, 3), dpi=334)
            ax.imshow(op_im, vmin=0, vmax=vmax, cmap=cmap, interpolation='nearest')
            ax.axis('off')

            if eval_params['show_context_points'] == 1:
                # Convert the tensor to a numpy array if it is not one already
                context_locs = context_locs_of_interest.cpu().numpy() if isinstance(context_locs_of_interest, torch.Tensor) else context_locs_of_interest
                # Convert normalized context locations to image coordinates,
                # skipping the dummy context point at index 0 (located at (0, 0)).
                image_x = (context_locs[1:, 0] + 1) / 2 * op_im.shape[1]  # lon: [-1, 1] -> [0, image width]
                image_y = (1 - (context_locs[1:, 1] + 1) / 2) * op_im.shape[0]  # lat: [-1, 1] -> [0, image height]

                # Plot each context location as a marker image; 'black_circle.png'
                # is expected to sit next to this script.
                def get_marker(path):
                    return OffsetImage(plt.imread(path), zoom=0.04)

                for x0, y0 in zip(image_x, image_y):
                    ab = AnnotationBbox(get_marker('black_circle.png'), (x0, y0), frameon=False)
                    ax.add_artist(ab)
                # Alternative: plain scatter markers instead of images.
                # plt.scatter(image_x, image_y, c='green', s=30, marker=r'$\checkmark$')

        # Save the image; with the defaults this produces a name like
        # images/testenv_sfs(selected_points)_range_5.png. str() guards against
        # taxa_id still being an int when context points are not pregenerated.
        save_loc = os.path.join(eval_params['op_path'],
                                'testenv_' + eval_params['taxa_name'] + '(' + str(eval_params['taxa_id']) + ')_'
                                + eval_params['text_type'] + '_' + str(number_of_context_points) + '.png')
        print(f'Saving image to {save_loc}')
        plt.savefig(save_loc, bbox_inches='tight', pad_inches=0, dpi=334)
        plt.show(block=False)  # Non-blocking: continue to the next context-point count.
    
    return True


if __name__ == '__main__':

    info_str = '\nDemo that takes an iNaturalist taxon ID as input, ' + \
               'generates a predicted range for each location on the globe, ' + \
               'and saves the output as an image.\n\n' + \
               'Warning: these estimated ranges should be validated before use.'

    parser = argparse.ArgumentParser(usage=info_str)
    parser.add_argument('--model_path', type=str, default='./experiments/zero_shot_ls_sin_cos_env_cap_1000_text_context_20_sinr_two_layer_nn/model.pt')
    parser.add_argument('--taxa_id', type=int, default=3352, help='iNaturalist taxon ID.')
    parser.add_argument('--threshold', type=float, default=-1, help='Threshold in [0, 1] applied to the range map; negative values disable thresholding.')
    parser.add_argument('--op_path', type=str, default='./images/', help='Location where the output image will be saved.')
    parser.add_argument('--rand_taxa', action='store_true', help='Select a random taxa.')
    parser.add_argument('--high_res', action='store_true', help='Generate higher resolution output.')
    parser.add_argument('--disable_ocean_mask', action='store_true', help='Do not use an ocean mask.')
    parser.add_argument('--set_max_cmap_to_1', action='store_true', help='Consistent maximum intensity output.')
    parser.add_argument('--device', type=str, default='cpu', help='cpu or cuda')
    parser.add_argument('--show_map', type=int, default=1, help='Show the map if 1.')
    parser.add_argument('--show_context_points', type=int, default=1, help='Also plot context points if 1.')
    parser.add_argument('--prefix', type=str, default='')
    parser.add_argument('--num_context', type=int, default=5, help='Number of context points when they are not chosen from pregenerated data.')
    parser.add_argument('--choose_context_points', type=int, default=1, help='Use pregenerated context points and text embeddings if 1.')
    parser.add_argument('--additional_save_name', type=str, default="")
    # Example taxa: black-and-white warbler (10286), hyacinth macaw (18938),
    # yellow baboon (67683), barn swallow (11901), pika (43188), loon (4626),
    # European robin (13094), southern flying squirrel (46272).
    parser.add_argument('--taxa_name', type=str, default='sfs', help='Name of the taxon.')
    parser.add_argument('--test_taxa', type=int, default=46272, help='Taxon ID to test.')
    parser.add_argument('--text_type', type=str, default='range', help='Type of text for input.')
    parser.add_argument('--context_pt_trial', type=int, nargs='+', default=[0, 1, 2, 5, 10, 20], help='List of context-point counts to try; one map is generated per count.')
    eval_params = vars(parser.parse_args())

    if not os.path.isdir(eval_params['op_path']):
        os.makedirs(eval_params['op_path'])

    # NOTE: high-res output is forced on here, overriding the --high_res flag.
    eval_params['high_res'] = True

    main(eval_params)