Rohith1112 commited on
Commit
0a7c920
·
verified ·
1 Parent(s): 6117b43

Upload gradio copy.ipynb

Browse files
Files changed (1) hide show
  1. gradio copy.ipynb +170 -0
gradio copy.ipynb ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "build_sam_vit_3d...\n"
13
+ ]
14
+ },
15
+ {
16
+ "name": "stderr",
17
+ "output_type": "stream",
18
+ "text": [
19
+ "/Users/rohith/Desktop/M3D copy/llama/lib/python3.10/site-packages/monai/utils/deprecate_utils.py:221: FutureWarning: monai.networks.blocks.patchembedding PatchEmbeddingBlock.__init__:pos_embed: Argument `pos_embed` has been deprecated since version 1.2. It will be removed in version 1.4. please use `proj_type` instead.\n",
20
+ " warn_deprecated(argname, msg, warning_category)\n",
21
+ "/Users/rohith/Desktop/M3D copy/llama/lib/python3.10/site-packages/monai/utils/deprecate_utils.py:221: FutureWarning: monai.networks.nets.vit ViT.__init__:pos_embed: Argument `pos_embed` has been deprecated since version 1.2. It will be removed in version 1.4. please use `proj_type` instead.\n",
22
+ " warn_deprecated(argname, msg, warning_category)\n"
23
+ ]
24
+ },
25
+ {
26
+ "data": {
27
+ "application/vnd.jupyter.widget-view+json": {
28
+ "model_id": "6884c5d2441e4a8a9f18b6e0595d119e",
29
+ "version_major": 2,
30
+ "version_minor": 0
31
+ },
32
+ "text/plain": [
33
+ "Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]"
34
+ ]
35
+ },
36
+ "metadata": {},
37
+ "output_type": "display_data"
38
+ },
39
+ {
40
+ "name": "stdout",
41
+ "output_type": "stream",
42
+ "text": [
43
+ "* Running on local URL: http://127.0.0.1:7860\n",
44
+ "\n",
45
+ "To create a public link, set `share=True` in `launch()`.\n"
46
+ ]
47
+ },
48
+ {
49
+ "data": {
50
+ "text/html": [
51
+ "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
52
+ ],
53
+ "text/plain": [
54
+ "<IPython.core.display.HTML object>"
55
+ ]
56
+ },
57
+ "metadata": {},
58
+ "output_type": "display_data"
59
+ },
60
+ {
61
+ "data": {
62
+ "text/plain": []
63
+ },
64
+ "execution_count": 1,
65
+ "metadata": {},
66
+ "output_type": "execute_result"
67
+ },
68
+ {
69
+ "name": "stderr",
70
+ "output_type": "stream",
71
+ "text": [
72
+ "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
73
+ "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)\n"
74
+ ]
75
+ },
76
+ {
77
+ "name": "stdout",
78
+ "output_type": "stream",
79
+ "text": [
80
+ "Using existing dataset file at: .gradio/flagged/dataset1.csv\n"
81
+ ]
82
+ }
83
+ ],
84
+ "source": [
85
+ "import numpy as np\n",
86
+ "import torch\n",
87
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
88
+ "import simple_slice_viewer as ssv\n",
89
+ "import SimpleITK as sikt\n",
90
+ "import gradio as gr\n",
91
+ "\n",
92
+ "device = torch.device('cpu') # Set to 'cuda' if using a GPU\n",
93
+ "dtype = torch.float32 # Data type for model processing\n",
94
+ "\n",
95
+ "model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'\n",
96
+ "proj_out_num = 256 # Number of projection outputs required for the image\n",
97
+ "\n",
98
+ "# Load model and tokenizer\n",
99
+ "model = AutoModelForCausalLM.from_pretrained(\n",
100
+ " model_name_or_path,\n",
101
+ " torch_dtype=torch.float32,\n",
102
+ " device_map='cpu',\n",
103
+ " trust_remote_code=True\n",
104
+ ")\n",
105
+ "\n",
106
+ "tokenizer = AutoTokenizer.from_pretrained(\n",
107
+ " model_name_or_path,\n",
108
+ " model_max_length=512,\n",
109
+ " padding_side=\"right\",\n",
110
+ " use_fast=False,\n",
111
+ " trust_remote_code=True\n",
112
+ ")\n",
113
+ "\n",
114
+ "def process_image(image_path, question):\n",
115
+ " # Load the image\n",
116
+ " image_np = np.load(image_path) # Load the .npy image\n",
117
+ " image_tokens = \"<im_patch>\" * proj_out_num\n",
118
+ " input_txt = image_tokens + question\n",
119
+ " input_id = tokenizer(input_txt, return_tensors=\"pt\")['input_ids'].to(device=device)\n",
120
+ "\n",
121
+ " # Prepare image for model\n",
122
+ " image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)\n",
123
+ "\n",
124
+ " # Generate model response\n",
125
+ " generation = model.generate(image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)\n",
126
+ " generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)\n",
127
+ "\n",
128
+ " return generated_texts[0]\n",
129
+ "\n",
130
+ "# Gradio Interface\n",
131
+ "def gradio_interface(image, question):\n",
132
+ " response = process_image(image.name, question)\n",
133
+ " return response\n",
134
+ "\n",
135
+ "# Gradio App\n",
136
+ "gr.Interface(\n",
137
+ " fn=gradio_interface,\n",
138
+ " inputs=[\n",
139
+ " gr.File(label=\"Upload .npy Image\", type=\"filepath\"), # For uploading .npy image\n",
140
+ " gr.Textbox(label=\"Enter your question\", placeholder=\"Ask something about the image...\"),\n",
141
+ " ],\n",
142
+ " outputs=gr.Textbox(label=\"Model Response\"),\n",
143
+ " title=\"Medical Image Analysis\",\n",
144
+ " description=\"Upload a .npy image and ask a question to analyze it using the model.\"\n",
145
+ ").launch()\n"
146
+ ]
147
+ }
148
+ ],
149
+ "metadata": {
150
+ "kernelspec": {
151
+ "display_name": "llama",
152
+ "language": "python",
153
+ "name": "python3"
154
+ },
155
+ "language_info": {
156
+ "codemirror_mode": {
157
+ "name": "ipython",
158
+ "version": 3
159
+ },
160
+ "file_extension": ".py",
161
+ "mimetype": "text/x-python",
162
+ "name": "python",
163
+ "nbconvert_exporter": "python",
164
+ "pygments_lexer": "ipython3",
165
+ "version": "3.10.15"
166
+ }
167
+ },
168
+ "nbformat": 4,
169
+ "nbformat_minor": 2
170
+ }