"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": []
+ },
+ "execution_count": 52,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import gradio as gr\n",
+ "\n",
+ "# Function to read CSS from file (improved readability)\n",
+ "def read_css_from_file(filename):\n",
+ " with open(filename, \"r\") as f:\n",
+ " return f.read()\n",
+ "\n",
+ "# Read CSS from file\n",
+ "css = read_css_from_file(\"style.css\")\n",
+ "\n",
+ "# The welcome message with improved styling (see style.css)\n",
+ "welcome_message = '''\n",
+ "\n",
+ "<div>\n",
+ " <h1>AI Medical Chatbot</h1>\n",
+ "</div>\n",
+ "<div>\n",
+ " <p>Ask any medical question and get answers from our AI Medical Chatbot</p>\n",
+ "</div>\n",
+ "<div>\n",
+ " <p>Developed by Ruslan Magana. Visit https://ruslanmv.com/ for more information.</p>\n",
+ "</div>\n",
+ "'''\n",
+ "\n",
+ "# Creating Gradio interface with full-screen styling\n",
+ "with gr.Blocks(css=css) as interface:\n",
+ " gr.Markdown(welcome_message) # Display the welcome message\n",
+ "\n",
+ " # Input and output elements\n",
+ " with gr.Row():\n",
+ " with gr.Column():\n",
+ " text_prompt = gr.Textbox(label=\"Input Prompt\", placeholder=\"Example: What are the symptoms of COVID-19?\", lines=2)\n",
+ " generate_button = gr.Button(\"Ask Me\", variant=\"primary\")\n",
+ "\n",
+ " with gr.Row():\n",
+ " answer_output = gr.Textbox(type=\"text\", label=\"Answer\")\n",
+ "\n",
+ " # Assumes a function `chat_v1` (defined in an earlier cell) that processes the prompt and returns a response\n",
+ " generate_button.click(chat_v1, inputs=[text_prompt], outputs=answer_output)\n",
+ "\n",
+ "# Launch the app\n",
+ "interface.launch(inline=True, share=False) #For the notebook\n",
+ "#interface.launch(server_name=\"0.0.0.0\",server_port=7860)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/ai-medical-chatbot-master/5-HuggingFace/notebook/local/img/cover.jpg b/ai-medical-chatbot-master/5-HuggingFace/notebook/local/img/cover.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c308c75a94489361965473367f78ec3429b30eba
Binary files /dev/null and b/ai-medical-chatbot-master/5-HuggingFace/notebook/local/img/cover.jpg differ
diff --git a/ai-medical-chatbot-master/5-HuggingFace/notebook/local/requirements.txt b/ai-medical-chatbot-master/5-HuggingFace/notebook/local/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e489578a3ea9651cbce8945e19c249b998b9d9bb
--- /dev/null
+++ b/ai-medical-chatbot-master/5-HuggingFace/notebook/local/requirements.txt
@@ -0,0 +1,401 @@
+absl-py==1.4.0
+accelerate==0.22.0
+aiofiles==23.1.0
+aiohttp==3.8.5
+aioredis==1.3.1
+aiosignal==1.3.1
+altair==5.1.1
+amqp==5.1.1
+annoy==1.17.3
+anyio @ file:///C:/ci/anyio_1644481856696/work/dist
+appdirs==1.4.4
+argon2-cffi @ file:///opt/conda/conda-bld/argon2-cffi_1645000214183/work
+argon2-cffi-bindings @ file:///C:/ci/argon2-cffi-bindings_1644569876605/work
+arrow==1.2.3
+arxiv==1.4.8
+asgiref==3.7.2
+astroid==2.6.6
+asttokens @ file:///opt/conda/conda-bld/asttokens_1646925590279/work
+astunparse==1.6.3
+async-lru==2.0.4
+async-timeout==4.0.3
+atomicwrites==1.4.1
+attrs==23.1.0
+audioread==3.0.1
+auto-gptq @ https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.4.2/auto_gptq-0.4.2+cu117-cp310-cp310-win_amd64.whl#sha256=7145db94f57db80d1d292880487870686079d1b83ef48d3043b9b01023301fa4
+autobahn==23.6.2
+Automat==22.10.0
+azure-core==1.30.1
+azure-storage-blob==12.19.1
+Babel==2.12.1
+backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
+backoff==2.2.1
+beautifulsoup4 @ file:///C:/b/abs_0agyz1wsr4/croot/beautifulsoup4-split_1681493048687/work
+billiard==3.6.4.0
+bitsandbytes @ https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl#sha256=adacda9b2b65dcb1931b222dffd7a91f0b611b3410d5b51c37ef7b22654106e6
+bleach @ file:///opt/conda/conda-bld/bleach_1641577558959/work
+blinker==1.7.0
+boto3==1.34.29
+botocore==1.34.29
+build==0.10.0
+CacheControl==0.13.1
+cachetools==5.3.1
+celery==5.0.5
+certifi==2022.12.7
+cffi @ file:///C:/b/abs_49n3v2hyhr/croot/cffi_1670423218144/work
+channels==3.0.5
+channels-redis==3.2.0
+charset-normalizer==2.1.1
+chromadb==0.3.26
+cleo==2.0.1
+click==8.1.7
+click-didyoumean==0.3.0
+click-plugins==1.1.1
+click-repl==0.3.0
+clickhouse-connect==0.7.0
+colorama @ file:///C:/b/abs_a9ozq0l032/croot/colorama_1672387194846/work
+coloredlogs==15.0.1
+comm==0.1.4
+constantly==15.1.0
+contourpy==1.1.0
+crashtest==0.4.1
+cryptography==41.0.4
+ctransformers @ https://github.com/jllllll/ctransformers-cuBLAS-wheels/releases/download/AVX2/ctransformers-0.2.25+cu117-py3-none-any.whl#sha256=e22c2c47640e30cbac4a779ad7624a47b89ed27d0daefa7b4cb79ea955424207
+cycler==0.11.0
+daphne==3.0.2
+dataclasses-json==0.6.3
+datasets==2.14.5
+debugpy @ file:///C:/b/abs_c0y1fjipt2/croot/debugpy_1690906864587/work
+decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work
+defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work
+dill==0.3.7
+diskcache==5.6.3
+distlib==0.3.7
+distro==1.9.0
+Django==4.1.11
+django-extensions==3.0.9
+docker-pycreds==0.4.0
+docutils==0.20.1
+duckdb==0.9.2
+dulwich==0.21.6
+einops==0.6.1
+emoji==2.10.1
+entrypoints @ file:///C:/ci/entrypoints_1649926676279/work
+environs==9.5.0
+exceptiongroup @ file:///C:/b/abs_25wqfvkf25/croot/exceptiongroup_1668714345637/work
+executing @ file:///opt/conda/conda-bld/executing_1646925071911/work
+exllama @ https://github.com/jllllll/exllama/releases/download/0.0.17/exllama-0.0.17+cu117-cp310-cp310-win_amd64.whl#sha256=64eff5fefde42b113c64e346c062e50ace5a648257053e889fd618026928b84f
+faiss==1.7.4
+faiss-cpu==1.7.3
+fashion-clip==0.2.2
+fastapi==0.95.2
+fastjsonschema @ file:///C:/Users/BUILDE~1/AppData/Local/Temp/abs_ebruxzvd08/croots/recipe/python-fastjsonschema_1661376484940/work
+feedparser==6.0.10
+ffmpy==0.3.1
+filelock==3.12.4
+Flask==3.0.0
+flatbuffers==23.5.26
+flickrapi==2.4.0
+fonttools==4.42.1
+fqdn==1.5.1
+frozenlist==1.4.0
+fsspec==2023.6.0
+gast==0.4.0
+gdown==4.7.1
+gitdb==4.0.10
+GitPython==3.1.35
+google-auth==2.22.0
+google-auth-oauthlib==1.0.0
+google-pasta==0.2.0
+gptq-for-llama @ https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.0/gptq_for_llama-0.1.0+cu117-cp310-cp310-win_amd64.whl#sha256=93e632ce0f29ac0b6ae84631b915df1b5d787fcd7dc961cd364edd9a8367b690
+gradio==3.33.1
+gradio_client==0.2.5
+greenlet==2.0.2
+grpcio==1.58.0
+h11==0.14.0
+h5py==3.9.0
+hiredis==2.2.3
+hnswlib==0.8.0
+httpcore==0.17.3
+httptools==0.6.1
+httpx==0.24.1
+huggingface-hub==0.16.4
+humanfriendly==10.0
+hyperlink==21.0.0
+ibm-cos-sdk==2.13.3
+ibm-cos-sdk-core==2.13.3
+ibm-cos-sdk-s3transfer==2.13.3
+ibm-watson-machine-learning==1.0.344
+idna @ file:///C:/b/abs_bdhbebrioa/croot/idna_1666125572046/work
+importlib-metadata==6.8.0
+importlib-resources==6.0.1
+incremental==22.10.0
+iniconfig==2.0.0
+installer==0.7.0
+ipykernel @ file:///C:/b/abs_07rkft_vaz/croot/ipykernel_1691121700587/work
+ipyplot==1.1.2
+ipython==8.16.1
+ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work
+ipywidgets==8.1.1
+isodate==0.6.1
+isoduration==20.11.0
+isort==5.12.0
+itsdangerous==2.1.2
+jaraco.classes==3.3.0
+jedi @ file:///C:/ci/jedi_1644315428305/work
+Jinja2 @ file:///C:/b/abs_7cdis66kl9/croot/jinja2_1666908141852/work
+jmespath==1.0.1
+joblib==1.3.2
+json5==0.9.14
+jsonpatch==1.33
+jsonpointer==2.4
+jsonschema==4.19.1
+jsonschema-specifications==2023.7.1
+jupyter-events==0.7.0
+jupyter-lsp==2.2.0
+jupyter_client @ file:///C:/b/abs_d8fk_kz9zk/croot/jupyter_client_1676330195659/work
+jupyter_core @ file:///C:/b/abs_9d0ttho3bs/croot/jupyter_core_1679906581955/work
+jupyter_server==2.7.3
+jupyter_server_terminals==0.4.4
+jupyterlab==4.0.6
+jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
+jupyterlab-widgets==3.0.9
+jupyterlab_server==2.25.0
+keras==2.13.1
+keyring==24.2.0
+kiwisolver==1.4.5
+kombu==5.3.2
+langchain==0.0.345
+langchain-community==0.0.15
+langchain-core==0.0.13
+langchain-openai==0.0.4
+langid==1.1.6
+langsmith==0.0.83
+lazy-object-proxy==1.9.0
+lazy_loader==0.3
+libclang==16.0.6
+librosa==0.10.1
+linkify-it-py==2.0.2
+llama-cpp-python @ https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.84/llama_cpp_python-0.1.84-cp310-cp310-win_amd64.whl#sha256=be549a1e26156af0e355153e7744cb17d7462991430997fb008c7473a0f181bf
+llama-cpp-python-cuda @ https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.1.84+cu117-cp310-cp310-win_amd64.whl#sha256=ea2dac857d79edf380eddc2a7c4eb3446a0da5cd623298926e845566497337fd
+llama-cpp-python-ggml @ https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python_ggml-0.1.78+cpuavx2-cp310-cp310-win_amd64.whl#sha256=6c0cb266a3c22d3a170efb2f19d6c63907efa82288d436e5127daf9ab54c6f9c
+llama-cpp-python-ggml-cuda @ https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_ggml_cuda-0.1.78+cu117-cp310-cp310-win_amd64.whl#sha256=04ca481d43a5b28c45959a6edad2126699461f99607417c7421625738901c112
+llvmlite==0.41.1
+loguru==0.7.2
+lomond==0.3.3
+lxml @ file:///C:/b/abs_c2bg6ck92l/croot/lxml_1679646459966/work
+lz4==4.3.3
+Markdown==3.4.4
+markdown-it-py==2.2.0
+MarkupSafe @ file:///C:/ci/markupsafe_1654508036328/work
+marshmallow==3.20.2
+matplotlib==3.7.2
+matplotlib-inline @ file:///C:/ci/matplotlib-inline_1661934094726/work
+mccabe==0.6.1
+mdit-py-plugins==0.3.3
+mdurl==0.1.2
+mediafire==0.6.1
+minio==7.2.5
+mistune @ file:///C:/ci_310/mistune_1642084168466/work
+mkl-fft==1.3.1
+mkl-random @ file:///C:/ci_310/mkl_random_1643050563308/work
+mkl-service==2.4.0
+monotonic==1.6
+more-itertools==10.1.0
+mpmath==1.2.1
+msgpack==1.0.7
+multidict==6.0.4
+multiprocess==0.70.15
+mypy-extensions==1.0.0
+nbclassic @ file:///C:/b/abs_c8_rs7b3zw/croot/nbclassic_1681756186106/work
+nbclient @ file:///C:/ci/nbclient_1650308592199/work
+nbconvert @ file:///C:/b/abs_4av3q4okro/croot/nbconvert_1668450658054/work
+nbformat @ file:///C:/b/abs_85_3g7dkt4/croot/nbformat_1670352343720/work
+nest-asyncio @ file:///C:/b/abs_3a_4jsjlqu/croot/nest-asyncio_1672387322800/work
+networkx==3.0
+nh3==0.2.15
+nltk==3.8.1
+noisereduce==3.0.0
+notebook @ file:///C:/b/abs_e2qn6c85jb/croot/notebook_1690985290943/work
+notebook_shim @ file:///C:/b/abs_ebfczttg6x/croot/notebook-shim_1668160590914/work
+numba==0.58.1
+numpy==1.24.0
+oauthlib==3.2.2
+onnxruntime==1.16.3
+openai==1.10.0
+opt-einsum==3.3.0
+optimum==1.12.0
+orjson==3.9.7
+overrides==7.4.0
+packaging==23.2
+pandas==1.5.3
+pandocfilters @ file:///opt/conda/conda-bld/pandocfilters_1643405455980/work
+parso @ file:///opt/conda/conda-bld/parso_1641458642106/work
+pathtools==0.1.2
+peft==0.5.0
+pexpect==4.8.0
+pgvector==0.2.4
+pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
+Pillow==10.0.0
+pkginfo==1.9.6
+platformdirs @ file:///C:/b/abs_b6z_yqw_ii/croot/platformdirs_1692205479426/work
+pluggy==1.3.0
+poetry==1.6.1
+poetry-core==1.7.0
+poetry-plugin-export==1.5.0
+pooch==1.8.1
+posthog==3.3.3
+pq==1.9.1
+prometheus-client @ file:///C:/Windows/TEMP/abs_ab9nx8qb08/croots/recipe/prometheus_client_1659455104602/work
+prompt-toolkit @ file:///C:/b/abs_6coz5_9f2s/croot/prompt-toolkit_1672387908312/work
+protobuf==4.24.3
+psutil @ file:///C:/Windows/Temp/abs_b2c2fd7f-9fd5-4756-95ea-8aed74d0039flsd9qufz/croots/recipe/psutil_1656431277748/work
+psycopg==3.1.17
+psycopg-binary==3.1.17
+psycopg2==2.9.9
+psycopg2-binary==2.9.9
+ptyprocess==0.7.0
+pulsar-client==3.4.0
+pure-eval @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work
+py==1.11.0
+py-cpuinfo==9.0.0
+pyarrow==13.0.0
+pyasn1==0.5.0
+pyasn1-modules==0.3.0
+pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
+pycryptodome==3.20.0
+pydantic==1.10.0
+pydub==0.25.1
+Pygments @ file:///C:/b/abs_fay9dpq4n_/croot/pygments_1684279990574/work
+pylint==2.7.4
+pymilvus==2.4.0
+PyMuPDF==1.23.3
+PyMuPDFb==1.23.3
+pynndescent==0.5.11
+pyOpenSSL==23.2.0
+pyparsing==3.0.9
+pypdf==4.0.0
+pyproject_hooks==1.0.0
+pyreadline3==3.4.1
+pyrsistent @ file:///C:/ci_310/pyrsistent_1642117077485/work
+PySocks==1.7.1
+pytest==6.2.5
+pytest-django==4.1.0
+pytest-mock==3.3.1
+python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work
+python-dotenv==1.0.1
+python-json-logger==2.0.7
+python-multipart==0.0.6
+pytz==2023.3.post1
+pywin32==305.1
+pywin32-ctypes==0.2.2
+pywinpty @ file:///C:/b/abs_73vshmevwq/croot/pywinpty_1677609966356/work/target/wheels/pywinpty-2.0.10-cp310-none-win_amd64.whl
+PyYAML==6.0.1
+pyzmq==25.1.1
+rag==0.1.0
+ragclip==0.0.4
+rapidfuzz==2.15.1
+readme_renderer==43.0
+redis==3.5.3
+referencing==0.30.2
+regex==2023.8.8
+requests==2.31.0
+requests-oauthlib==1.3.1
+requests-toolbelt==1.0.0
+rfc3339-validator==0.1.4
+rfc3986==2.0.0
+rfc3986-validator==0.1.1
+rich==13.7.1
+rouge==1.0.1
+rpds-py==0.10.2
+rsa==4.9
+s3transfer==0.10.0
+safetensors==0.3.2
+scidownl==1.0.2
+scikit-learn==1.3.0
+scipy==1.11.1
+semantic-version==2.10.0
+Send2Trash==1.8.2
+sentence-transformers==2.2.2
+sentencepiece==0.1.99
+sentry-sdk==1.30.0
+service-identity==23.1.0
+setproctitle==1.3.2
+sgmllib3k==1.0.0
+shellingham==1.5.3
+shortuuid==1.0.11
+six @ file:///tmp/build/80754af9/six_1644875935023/work
+sklearn==0.0.post9
+smmap==5.0.0
+sniffio @ file:///C:/ci_310/sniffio_1642092172680/work
+soundfile==0.12.1
+soupsieve @ file:///C:/b/abs_a989exj3q6/croot/soupsieve_1680518492466/work
+soxr==0.3.7
+SQLAlchemy==2.0.20
+sqlparse==0.4.4
+stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work
+starlette==0.27.0
+sympy==1.11.1
+tablib==3.5.0
+tabulate==0.9.0
+tenacity==8.2.3
+tensorboard==2.13.0
+tensorboard-data-server==0.7.1
+tensorflow==2.13.0
+tensorflow-estimator==2.13.0
+tensorflow-hub==0.14.0
+tensorflow-intel==2.13.0
+tensorflow-io-gcs-filesystem==0.31.0
+termcolor==2.3.0
+terminado @ file:///C:/b/abs_25nakickad/croot/terminado_1671751845491/work
+threadpoolctl==3.2.0
+tiktoken==0.5.2
+tinycss2 @ file:///C:/b/abs_52w5vfuaax/croot/tinycss2_1668168823131/work
+tokenizers==0.13.3
+toml==0.10.2
+tomli==2.0.1
+tomlkit==0.12.1
+toolz==0.12.0
+torch==2.0.1+cu117
+torchaudio==2.0.2+cu117
+torchvision==0.15.2+cu117
+tornado @ file:///C:/b/abs_61jhmrrua1/croot/tornado_1690848767317/work
+towhee==1.1.3
+tqdm==4.66.1
+traitlets @ file:///C:/b/abs_e5m_xjjl94/croot/traitlets_1671143896266/work
+transformers==4.33.1
+trove-classifiers==2023.9.19
+twine==5.0.0
+Twisted==23.8.0
+twisted-iocpsupport==1.0.4
+txaio==23.1.1
+typing-inspect==0.9.0
+typing_extensions @ file:///C:/b/abs_213vg2cd1l/croot/typing_extensions_1690297804941/work
+tzdata==2023.3
+uc-micro-py==1.0.2
+ujson==5.9.0
+umap-learn==0.3.10
+uri-template==1.3.0
+urllib3==1.26.18
+uvicorn==0.23.2
+validators==0.22.0
+vine==5.0.0
+virtualenv==20.24.5
+wandb==0.15.10
+watchdog==0.10.4
+watchfiles==0.21.0
+wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work
+webcolors==1.13
+webencodings==0.5.1
+websocket-client @ file:///C:/ci_310/websocket-client_1642093970919/work
+websockets==11.0.3
+Werkzeug==3.0.1
+wget==3.2
+widgetsnbextension==4.0.9
+win32-setctime==1.1.0
+wrapt==1.12.1
+xxhash==3.3.0
+yarl==1.9.2
+zipp==3.17.0
+zope.interface==6.0
+zstandard==0.22.0
diff --git a/ai-medical-chatbot-master/5-HuggingFace/notebook/local/style.css b/ai-medical-chatbot-master/5-HuggingFace/notebook/local/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..acb3595699672b7fcac37c0598c4c8fbc7098398
--- /dev/null
+++ b/ai-medical-chatbot-master/5-HuggingFace/notebook/local/style.css
@@ -0,0 +1,73 @@
+/* General Container Styles */
+.gradio-container {
+ font-family: "IBM Plex Sans", sans-serif;
+ position: fixed; /* Ensure full-screen coverage */
+ top: 0;
+ left: 0;
+ width: 100vw; /* Set width to 100% viewport width */
+ height: 100vh; /* Set height to 100% viewport height */
+ margin: 0; /* Remove margins for full-screen effect */
+ padding: 0; /* Remove padding for full-screen background */
+ background-color: #212529; /* Dark background color */
+ color: #fff; /* Light text color for better readability */
+ overflow: hidden; /* Hide potential overflow content */
+ background-image: url("https://raw.githubusercontent.com/ruslanmv/ai-medical-chatbot/master/assets/images/background.jpg"); /* Replace with your image path */
+ background-size: cover; /* Stretch the image to cover the container */
+ background-position: center; /* Center the image horizontally and vertically */
+}
+/* Button Styles */
+.gr-button {
+ color: white;
+ background: #007bff; /* Use a primary color for the background */
+ white-space: nowrap;
+ border: none;
+ padding: 10px 20px;
+ border-radius: 8px;
+ cursor: pointer;
+ transition: background-color 0.3s, color 0.3s;
+}
+.gr-button:hover {
+ background-color: #0056b3; /* Darken the background color on hover */
+}
+
+/* Share Button Styles (omitted as not directly affecting dark mode) */
+/* ... */
+
+/* Other styles (adjustments for full-screen might be needed) */
+#gallery {
+ min-height: 22rem;
+ /* Center the gallery horizontally (optional) */
+ margin: auto;
+ border-bottom-right-radius: 0.5rem !important;
+ border-bottom-left-radius: 0.5rem !important;
+ background-color: #212529; /* Dark background color for elements */
+}
+
+/* Centered Container for the Image */
+.image-container {
+ max-width: 100%; /* Set the maximum width for the container */
+ margin: auto; /* Center the container horizontally */
+ padding: 20px; /* Add padding for spacing */
+ border: 1px solid #a50909; /* Add a subtle border to the container */
+ border-radius: 10px;
+ overflow: hidden; /* Hide overflow if the image is larger */
+ max-height: 22rem; /* Set a maximum height for the container */
+ background-color: #212529; /* Dark background color for elements */
+}
+
+/* Constrain the image to its container while keeping its aspect ratio */
+.image-container img {
+ max-width: 100%; /* Ensure the image fills the container */
+ height: auto; /* Maintain aspect ratio */
+ max-height: 100%;
+ border-radius: 10px;
+ box-shadow: 0px 2px 4px rgba(0, 0, 0, 0.2);
+}
+
+/* Output box styles */
+.gradio-textbox {
+ background-color: #343a40; /* Dark background color */
+ color: #fff; /* Light text color for better readability */
+ border-color: #343a40; /* Dark border color */
+ border-radius: 8px;
+}
\ No newline at end of file
diff --git a/ai-medical-chatbot-master/5-HuggingFace/notebook/watsonx/chatbot.ipynb b/ai-medical-chatbot-master/5-HuggingFace/notebook/watsonx/chatbot.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..54d3898769cd604e3bb67d2a92f7ffdb5e5ec4e9
--- /dev/null
+++ b/ai-medical-chatbot-master/5-HuggingFace/notebook/watsonx/chatbot.ipynb
@@ -0,0 +1,208 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from datasets import load_dataset\n",
+ "from IPython.display import clear_output\n",
+ "import pandas as pd\n",
+ "import re\n",
+ "import os\n",
+ "from dotenv import load_dotenv\n",
+ "from ibm_watson_machine_learning.foundation_models.utils.enums import ModelTypes, DecodingMethods\n",
+ "from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams\n",
+ "from langchain.llms import WatsonxLLM\n",
+ "from langchain.embeddings import SentenceTransformerEmbeddings\n",
+ "from langchain.embeddings.base import Embeddings\n",
+ "from langchain.vectorstores.milvus import Milvus\n",
+ "from langchain.embeddings import HuggingFaceEmbeddings # Not used in this example\n",
+ "from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility\n",
+ "from towhee import pipe, ops\n",
+ "from towhee.datacollection import DataCollection\n",
+ "import numpy as np\n",
+ "#import langchain.chains as lc\n",
+ "from typing import List\n",
+ "from langchain.chains import RetrievalQA\n",
+ "from langchain.prompts import PromptTemplate\n",
+ "from langchain.schema.runnable import RunnablePassthrough\n",
+ "from langchain_core.retrievers import BaseRetriever\n",
+ "from langchain_core.callbacks import CallbackManagerForRetrieverRun\n",
+ "from langchain_core.documents import Document\n",
+ "\n",
+ "print_full_prompt=False\n",
+ "\n",
+ "## Step 1 Dataset Retrieving\n",
+ "\n",
+ "dataset = load_dataset(\"ruslanmv/ai-medical-chatbot\")\n",
+ "clear_output()\n",
+ "train_data = dataset[\"train\"]\n",
+ "#For this demo let us choose the first 1000 dialogues\n",
+ "\n",
+ "df = pd.DataFrame(train_data[:1000])\n",
+ "#df = df[[\"Patient\", \"Doctor\"]].rename(columns={\"Patient\": \"question\", \"Doctor\": \"answer\"})\n",
+ "df = df[[\"Description\", \"Doctor\"]].rename(columns={\"Description\": \"question\", \"Doctor\": \"answer\"})\n",
+ "# Add the 'ID' column as the first column\n",
+ "df.insert(0, 'id', df.index)\n",
+ "# Reset the index and drop the previous index column\n",
+ "df = df.reset_index(drop=True)\n",
+ "\n",
+ "# Clean the 'question' and 'answer' columns\n",
+ "df['question'] = df['question'].apply(lambda x: re.sub(r'\\s+', ' ', x.strip()))\n",
+ "df['answer'] = df['answer'].apply(lambda x: re.sub(r'\\s+', ' ', x.strip()))\n",
+ "df['question'] = df['question'].str.replace('^Q.', '', regex=True)\n",
+ "# Assuming your DataFrame is named df\n",
+ "max_length = 500 # Our embedding model does not accept long strings\n",
+ "df['question'] = df['question'].str.slice(0, max_length)\n",
+ "#To use the dataset to get answers, let's first define the dictionary:\n",
+ "#- `id_answer`: a dictionary of id and corresponding answer\n",
+ "id_answer = df.set_index('id')['answer'].to_dict()\n",
+ "\n",
+ "## Step 2 WatsonX connection\n",
+ "\n",
+ "load_dotenv()\n",
+ "try:\n",
+ " API_KEY = os.environ[\"API_KEY\"]\n",
+ " project_id = os.environ[\"PROJECT_ID\"]\n",
+ "except KeyError:\n",
+ " API_KEY = input(\"Please enter your WML api key (hit enter): \")\n",
+ " project_id = input(\"Please enter your project_id (hit enter): \")\n",
+ "\n",
+ "credentials = {\n",
+ " \"url\": \"https://us-south.ml.cloud.ibm.com\",\n",
+ " \"apikey\": API_KEY \n",
+ "} \n",
+ "\n",
+ "model_id = ModelTypes.GRANITE_13B_CHAT_V2\n",
+ "\n",
+ "\n",
+ "parameters = {\n",
+ " GenParams.DECODING_METHOD: DecodingMethods.GREEDY,\n",
+ " GenParams.MIN_NEW_TOKENS: 1,\n",
+ " GenParams.MAX_NEW_TOKENS: 500,\n",
+ " GenParams.STOP_SEQUENCES: [\"<|endoftext|>\"]\n",
+ "}\n",
+ "\n",
+ "\n",
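+ "# Wrap the watsonx.ai foundation model as a LangChain LLM\n",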
+ "watsonx_granite = WatsonxLLM(\n",
+ " model_id=model_id.value,\n",
+ " url=credentials.get(\"url\"),\n",
+ " apikey=credentials.get(\"apikey\"),\n",
+ " project_id=project_id,\n",
+ " params=parameters\n",
+ ")\n",
+ "\n",
+ "\n",
+ "## Step 3 Milvus connection\n",
+ "\n",
+ "COLLECTION_NAME='qa_medical'\n",
+ "load_dotenv()\n",
+ "host_milvus = os.environ.get(\"REMOTE_SERVER\", '127.0.0.1')\n",
+ "connections.connect(host=host_milvus, port='19530')\n",
+ "\n",
+ "\n",
+ "collection = Collection(COLLECTION_NAME) \n",
+ "collection.load(replica_number=1)\n",
+ "utility.load_state(COLLECTION_NAME)\n",
+ "utility.loading_progress(COLLECTION_NAME)\n",
+ "\n",
+ "\n",
+ "max_input_length = 500 # Maximum length allowed by the model\n",
+ "\n",
+ "\n",
+ "\n",
+ "# Create the combined pipe for question encoding and answer retrieval\n",
+ "combined_pipe = (\n",
+ " pipe.input('question')\n",
+ " .map('question', 'vec', lambda x: x[:max_input_length]) # Truncate the question to at most max_input_length characters\n",
+ " .map('vec', 'vec', ops.text_embedding.dpr(model_name='facebook/dpr-ctx_encoder-single-nq-base'))\n",
+ " .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))\n",
+ " .map('vec', 'res', ops.ann_search.milvus_client(host=host_milvus, port='19530', collection_name=COLLECTION_NAME, limit=1))\n",
+ " .map('res', 'answer', lambda x: [id_answer[int(i[0])] for i in x])\n",
+ " .output('question', 'answer')\n",
+ ")\n",
+ " \n",
+ "# Step 4 Langchain Definitions\n",
+ "\n",
+ "class CustomRetrieverLang(BaseRetriever): \n",
+ " def get_relevant_documents(\n",
+ " self, query: str, *, run_manager: CallbackManagerForRetrieverRun\n",
+ " ) -> List[Document]:\n",
+ " # Perform the encoding and retrieval for a specific question\n",
+ " ans = combined_pipe(query)\n",
+ " ans = DataCollection(ans)\n",
+ " answer=ans[0]['answer']\n",
+ " answer_string = ' '.join(answer)\n",
+ " return [Document(page_content=answer_string)] \n",
+ "# Instantiate the custom retriever used by the chain\n",
+ "retriever = CustomRetrieverLang()\n",
+ "\n",
+ "# Define the prompt template\n",
+ "template = \"\"\"Use the following pieces of context to answer the question at the end. \n",
+ "If you don't know the answer, just say that you don't know, don't try to make up an answer. \n",
+ "Use three sentences maximum and keep the answer as concise as possible. \n",
+ "Always say \"thanks for asking!\" at the end of the answer. \n",
+ "{context}\n",
+ "Question: {question}\n",
+ "Helpful Answer:\"\"\"\n",
+ "rag_prompt = PromptTemplate.from_template(template)\n",
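+ "# Compose the RAG chain: the retriever fills {context} and the user question passes through to {question}\n",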
+ "rag_chain = (\n",
+ " {\"context\": retriever, \"question\": RunnablePassthrough()}\n",
+ " | rag_prompt\n",
+ " | watsonx_granite\n",
+ ")\n",
+ "\n",
+ "prompt = \"I have started to get lots of acne on my face, particularly on my forehead what can I do\"\n",
+ "\n",
+ "if print_full_prompt:\n",
+ " # Get the retrieved context\n",
+ " context = retriever.get_relevant_documents(prompt)\n",
+ " print(\"Retrieved context:\")\n",
+ " for doc in context:\n",
+ " print(doc)\n",
+ " # Construct the full prompt\n",
+ " full_prompt = rag_prompt.format(context=context, question=prompt)\n",
+ " print(\"Full prompt:\", full_prompt)\n",
+ "\n",
+ "print(rag_chain.invoke(prompt)) \n",
+ "\n",
+ "import towhee\n",
+ "def chat(message, history):\n",
+ " history = history or []\n",
+ " response = rag_chain.invoke(message)\n",
+ " history.append((message, response))\n",
+ " return history, history\n",
+ "\n",
+ "import gradio\n",
+ "collection.load()\n",
+ "chatbot = gradio.Chatbot()\n",
+ "interface = gradio.Interface(\n",
+ " chat,\n",
+ " [\"text\", \"state\"],\n",
+ " [chatbot, \"state\"],\n",
+ " allow_flagging=\"never\",\n",
+ ")\n",
+ "#interface.launch(inline=True, share=False) #For the notebook\n",
+ "interface.launch(server_name=\"0.0.0.0\",server_port=7860)"
+ ]
+ }
+ ],
+ "metadata": {
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/ai-medical-chatbot-master/5-HuggingFace/requirements.txt b/ai-medical-chatbot-master/5-HuggingFace/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a0c6bbdc70d1ec4f233b6a9f537817610ca7541a
--- /dev/null
+++ b/ai-medical-chatbot-master/5-HuggingFace/requirements.txt
@@ -0,0 +1,193 @@
+aiofiles==23.2.1
+aiohttp==3.9.3
+aiosignal==1.3.1
+altair==5.2.0
+annotated-types==0.6.0
+anyio==3.7.1
+argon2-cffi==23.1.0
+argon2-cffi-bindings==21.2.0
+asttokens==2.4.1
+async-timeout==4.0.3
+attrs==23.2.0
+backoff==2.2.1
+beautifulsoup4==4.12.3
+bs4==0.0.2
+certifi==2024.2.2
+cffi==1.16.0
+charset-normalizer==3.3.2
+chromadb==0.3.22
+click==8.1.7
+clickhouse-connect==0.7.0
+comm==0.2.1
+contourpy==1.2.0
+cryptography==42.0.3
+cycler==0.12.1
+dataclasses-json==0.6.4
+datasets==2.17.1
+debugpy==1.8.1
+decorator==5.1.1
+dill==0.3.8
+docutils==0.20.1
+duckdb==0.10.0
+environs==9.5.0
+exceptiongroup==1.2.0
+executing==2.0.1
+fastapi==0.109.2
+ffmpy==0.3.2
+filelock==3.13.1
+fonttools==4.49.0
+frozenlist==1.4.1
+fsspec==2023.10.0
+gradio==3.50.2
+gradio_client==0.6.1
+greenlet==3.0.3
+grpcio==1.60.0
+h11==0.14.0
+hnswlib==0.8.0
+httpcore==1.0.3
+httptools==0.6.1
+httpx==0.26.0
+huggingface-hub==0.20.3
+ibm-cos-sdk==2.13.4
+ibm-cos-sdk-core==2.13.4
+ibm-cos-sdk-s3transfer==2.13.4
+ibm-watson-machine-learning==1.0.347
+idna==3.6
+importlib-metadata==7.0.1
+importlib-resources==6.1.1
+ipykernel==6.29.2
+ipython==8.21.0
+ipywidgets==8.1.2
+jaraco.classes==3.3.1
+jedi==0.19.1
+jeepney==0.8.0
+Jinja2==3.1.3
+jmespath==1.0.1
+joblib==1.3.2
+jsonpatch==1.33
+jsonpointer==2.4
+jsonschema==4.21.1
+jsonschema-specifications==2023.12.1
+jupyter_client==8.6.0
+jupyter_core==5.7.1
+jupyterlab_widgets==3.0.10
+keyring==24.3.0
+kiwisolver==1.4.5
+langchain==0.0.345
+langchain-core==0.0.13
+langsmith==0.0.92
+lomond==0.3.3
+lz4==4.3.3
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+marshmallow==3.20.2
+matplotlib==3.8.3
+matplotlib-inline==0.1.6
+mdurl==0.1.2
+minio==7.2.4
+monotonic==1.6
+more-itertools==10.2.0
+mpmath==1.3.0
+multidict==6.0.5
+multiprocess==0.70.16
+mypy-extensions==1.0.0
+nest-asyncio==1.6.0
+networkx==3.2.1
+nh3==0.2.15
+nltk==3.8.1
+numpy==1.26.4
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu12==8.9.2.26
+nvidia-cufft-cu12==11.0.2.54
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-nccl-cu12==2.19.3
+nvidia-nvjitlink-cu12==12.3.101
+nvidia-nvtx-cu12==12.1.105
+orjson==3.9.14
+packaging==23.2
+pandas==1.5.3
+parso==0.8.3
+pexpect==4.9.0
+pillow==10.2.0
+pkginfo==1.9.6
+platformdirs==4.2.0
+posthog==3.4.1
+prompt-toolkit==3.0.43
+protobuf==4.25.3
+psutil==5.9.8
+ptyprocess==0.7.0
+pure-eval==0.2.2
+pyarrow==15.0.0
+pyarrow-hotfix==0.6
+pycparser==2.21
+pycryptodome==3.20.0
+pydantic==1.10.14
+pydantic_core==2.16.2
+pydub==0.25.1
+Pygments==2.17.2
+pymilvus==2.3.6
+pyparsing==3.1.1
+python-dateutil==2.8.2
+python-dotenv==1.0.1
+python-multipart==0.0.9
+pytz==2024.1
+PyYAML==6.0.1
+pyzmq==25.1.2
+readme-renderer==42.0
+referencing==0.33.0
+regex==2023.12.25
+requests==2.31.0
+requests-toolbelt==1.0.0
+rfc3986==2.0.0
+rich==13.7.0
+rpds-py==0.18.0
+safetensors==0.4.2
+scikit-learn==1.4.1.post1
+scipy==1.12.0
+SecretStorage==3.3.3
+semantic-version==2.10.0
+sentence-transformers==2.3.1
+sentencepiece==0.2.0
+six==1.16.0
+sniffio==1.3.0
+soupsieve==2.5
+SQLAlchemy==2.0.27
+stack-data==0.6.3
+starlette==0.36.3
+sympy==1.12
+tabulate==0.9.0
+tenacity==8.2.3
+threadpoolctl==3.3.0
+tokenizers==0.15.2
+toolz==0.12.1
+torch==2.2.0
+tornado==6.4
+towhee==1.1.3
+towhee.models==1.1.3
+tqdm==4.66.2
+traitlets==5.14.1
+transformers==4.37.2
+triton==2.2.0
+twine==5.0.0
+typing-inspect==0.9.0
+typing_extensions==4.9.0
+tzdata==2024.1
+ujson==5.9.0
+urllib3==2.1.0
+uvicorn==0.27.1
+uvloop==0.19.0
+watchfiles==0.21.0
+wcwidth==0.2.13
+websockets==11.0.3
+wget==3.2
+widgetsnbextension==4.0.10
+xxhash==3.4.1
+yarl==1.9.4
+zipp==3.17.0
+zstandard==0.22.0
+openai==1.10.0
diff --git a/ai-medical-chatbot-master/5-HuggingFace/style.css b/ai-medical-chatbot-master/5-HuggingFace/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..acb3595699672b7fcac37c0598c4c8fbc7098398
--- /dev/null
+++ b/ai-medical-chatbot-master/5-HuggingFace/style.css
@@ -0,0 +1,73 @@
+/* General Container Styles */
+.gradio-container {
+ font-family: "IBM Plex Sans", sans-serif;
+ position: fixed; /* Ensure full-screen coverage */
+ top: 0;
+ left: 0;
+ width: 100vw; /* Set width to 100% viewport width */
+ height: 100vh; /* Set height to 100% viewport height */
+ margin: 0; /* Remove margins for full-screen effect */
+ padding: 0; /* Remove padding for full-screen background */
+ background-color: #212529; /* Dark background color */
+ color: #fff; /* Light text color for better readability */
+ overflow: hidden; /* Hide potential overflow content */
+ background-image: url("https://raw.githubusercontent.com/ruslanmv/ai-medical-chatbot/master/assets/images/background.jpg"); /* Replace with your image path */
+ background-size: cover; /* Stretch the image to cover the container */
+ background-position: center; /* Center the image horizontally and vertically */
+}
+/* Button Styles */
+.gr-button {
+ color: white;
+ background: #007bff; /* Use a primary color for the background */
+ white-space: nowrap;
+ border: none;
+ padding: 10px 20px;
+ border-radius: 8px;
+ cursor: pointer;
+ transition: background-color 0.3s, color 0.3s;
+}
+.gr-button:hover {
+ background-color: #0056b3; /* Darken the background color on hover */
+}
+
+/* Share Button Styles (omitted as not directly affecting dark mode) */
+/* ... */
+
+/* Other styles (adjustments for full-screen might be needed) */
+#gallery {
+ min-height: 22rem;
+ /* Center the gallery horizontally (optional) */
+ margin: auto;
+ border-bottom-right-radius: 0.5rem !important;
+ border-bottom-left-radius: 0.5rem !important;
+ background-color: #212529; /* Dark background color for elements */
+}
+
+/* Centered Container for the Image */
+.image-container {
+ max-width: 100%; /* Set the maximum width for the container */
+ margin: auto; /* Center the container horizontally */
+ padding: 20px; /* Add padding for spacing */
+ border: 1px solid #a50909; /* Add a subtle border to the container */
+ border-radius: 10px;
+ overflow: hidden; /* Hide overflow if the image is larger */
+ max-height: 22rem; /* Set a maximum height for the container */
+ background-color: #212529; /* Dark background color for elements */
+}
+
+/* Constrain the image to its container while keeping its aspect ratio */
+.image-container img {
+ max-width: 100%; /* Ensure the image fills the container */
+ height: auto; /* Maintain aspect ratio */
+ max-height: 100%;
+ border-radius: 10px;
+ box-shadow: 0px 2px 4px rgba(0, 0, 0, 0.2);
+}
+
+/* Output box styles */
+.gradio-textbox {
+ background-color: #343a40; /* Dark background color */
+ color: #fff; /* Light text color for better readability */
+ border-color: #343a40; /* Dark border color */
+ border-radius: 8px;
+}
\ No newline at end of file
diff --git a/ai-medical-chatbot-master/6-FineTunning/Fine_Tunning_Medical_Mind_Llama3.ipynb b/ai-medical-chatbot-master/6-FineTunning/Fine_Tunning_Medical_Mind_Llama3.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..7463d204505152a2e8ef04e621acdf9d1ececcce
--- /dev/null
+++ b/ai-medical-chatbot-master/6-FineTunning/Fine_Tunning_Medical_Mind_Llama3.ipynb
@@ -0,0 +1,4925 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "ew25zGr2oXpx",
+ "outputId": "22349eb1-d8d4-47f2-ec49-3a558512ec66"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install tqdm\n",
+ "!pip install transformers==4.40.1\n",
+ "!pip install sentencepiece\n",
+ "!pip install datasets\n",
+ "!pip install \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
+ "!pip install trl\n",
+ "!pip install triton\n",
+ "!pip install bitsandbytes\n",
+ "!pip install --no-deps trl peft accelerate bitsandbytes\n",
+ "!pip install xformers\n",
+ "!pip install pytorch-cuda==12.1 torch xformers\n",
+ "#!pip install --no-deps xformers trl peft accelerate bitsandbytes\n",
+ "#!pip install \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
+ "!pip install hyperopt\n",
+ "!pip install optuna"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "dH4JvbO9oiHE",
+ "outputId": "399bc210-c095-4807-900f-6b4cf2fe133f"
+ },
+ "outputs": [],
+ "source": [
+ "!python -m xformers.info\n",
+ "!python -m bitsandbytes\n",
+ "!nvidia-smi\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "3sl1jjxFMeVx",
+ "outputId": "f0221fec-8c3f-4fbe-eb8a-e58df86399ce"
+ },
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "import torch\n",
+ "from datasets import load_dataset\n",
+ "from huggingface_hub import notebook_login\n",
+ "from transformers import TrainingArguments\n",
+ "from trl import SFTTrainer\n",
+ "from unsloth import FastLanguageModel\n",
+ "print(torch.__version__)\n",
+ "print(torch.version.cuda)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "CIwMK9N7j8Dx"
+ },
+ "outputs": [],
+ "source": [
+ "# Defining the configuration for the base model, LoRA and training\n",
+ "config = {\n",
+ " \"hugging_face_username\":\"ruslanmv\",\n",
+ " \"model_config\": {\n",
+ " \"base_model\":\"meta-llama/Meta-Llama-3-8B-Instruct\", # The base model\n",
+ " \"finetuned_model\":\"ruslanmv/Medical-Mind-Llama-3-8b\", # The finetuned model\n",
+ " \"max_seq_length\": 2048, # The maximum sequence length\n",
+ " # \"dtype\":torch.float16, # The data type\n",
+ " # \"dtype\": torch.float32, # Use float32 instead of half CUDA capability < 8\n",
+ " \"dtype\" : None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+\n",
+ "\n",
+ " \"load_in_4bit\": True, # Load the model in 4-bit\n",
+ " },\n",
+ " \"lora_config\": {\n",
+ " \"r\": 16, # The LoRA rank (commonly 8, 16, 32 or 64)\n",
+ " \"target_modules\": [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
+ " \"gate_proj\", \"up_proj\", \"down_proj\"], # The target modules\n",
+ " \"lora_alpha\":16, # The alpha value for LoRA\n",
+ " #\"lora_alpha\":15, # The alpha value for LoRA by search grid\n",
+ " \"lora_dropout\":0, # The dropout value for LoRA\n",
+ " \"bias\":\"none\", # The bias for LoRA\n",
+ " \"use_gradient_checkpointing\":True, # Use gradient checkpointing\n",
+ " \"use_rslora\":False, # Use RSLora\n",
+ " \"use_dora\":False, # Use DoRa\n",
+ " \"loftq_config\":None # The LoFTQ configuration\n",
+ " },\n",
+ "\n",
+ " \"training_config\": {\n",
+ " \"per_device_train_batch_size\": 2, # The batch size\n",
+ " #\"per_device_train_batch_size\": 6, # The batch size by search grid\n",
+ " \"gradient_accumulation_steps\": 4, # The gradient accumulation steps\n",
+ " #\"gradient_accumulation_steps\": 7, # The gradient accumulation steps by search grid\n",
+ " \"warmup_steps\": 5, # The warmup steps\n",
+ " \"max_steps\":0, # The maximum steps (0 if the epochs are defined)\n",
+ " \"num_train_epochs\": 1, # The number of training epochs(0 if the maximum steps are defined)\n",
+ " \"learning_rate\": 2e-4, # The learning rate\n",
+ " #\"learning_rate\": 9.5e-05, # The learning rate by search grid\n",
+ " \"fp16\": not torch.cuda.is_bf16_supported(), # The fp16\n",
+ " \"bf16\": torch.cuda.is_bf16_supported(), # The bf16\n",
+ " \"logging_steps\": 1, # The logging steps\n",
+ " \"optim\" :\"adamw_8bit\", # The optimizer\n",
+ " \"weight_decay\" : 0.01, # The weight decay\n",
+ " \"lr_scheduler_type\": \"linear\", # The learning rate scheduler\n",
+ " \"seed\" : 42, # The seed\n",
+ " \"output_dir\" : \"outputs\", # The output directory\n",
+ " }\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "pyN9TpNj8lUQ"
+ },
+ "outputs": [],
+ "source": [
+ "config_dataset={ \"training_dataset\": {\n",
+ " \"name\": \"ruslanmv/ai-medical-dataset\", # The dataset name(huggingface/datasets)\n",
+ " \"split\": \"train\", # The dataset split\n",
+ " \"input_fields\": [\"question\", \"context\"] ,# The input fields\n",
+ " \"input_field\": \"text\",# The input field\n",
+ " },\n",
+ " }"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "ztqrczgo9Zrg",
+ "outputId": "09eaabca-aba6-485a-8bab-6ddd96b077b9"
+ },
+ "outputs": [],
+ "source": [
+ "# Loading the base model and its tokenizer\n",
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ " model_name = config.get(\"model_config\").get(\"base_model\"),\n",
+ " max_seq_length = config.get(\"model_config\").get(\"max_seq_length\"),\n",
+ " dtype = config.get(\"model_config\").get(\"dtype\"),\n",
+ " load_in_4bit = config.get(\"model_config\").get(\"load_in_4bit\"),\n",
+ "\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "hSQMljYD9hCh",
+ "outputId": "7081b216-b5a7-4b36-fe01-91c166d9e491"
+ },
+ "outputs": [],
+ "source": [
+ "# Setup for QLoRA/LoRA peft of the base model\n",
+ "model = FastLanguageModel.get_peft_model(\n",
+ " model,\n",
+ " r = config.get(\"lora_config\").get(\"r\"),\n",
+ " target_modules = config.get(\"lora_config\").get(\"target_modules\"),\n",
+ " lora_alpha = config.get(\"lora_config\").get(\"lora_alpha\"),\n",
+ " lora_dropout = config.get(\"lora_config\").get(\"lora_dropout\"),\n",
+ " bias = config.get(\"lora_config\").get(\"bias\"),\n",
+ " use_gradient_checkpointing = config.get(\"lora_config\").get(\"use_gradient_checkpointing\"),\n",
+ " random_state = 42,\n",
+ " use_rslora = config.get(\"lora_config\").get(\"use_rslora\"),\n",
+ " use_dora = config.get(\"lora_config\").get(\"use_dora\"),\n",
+ " loftq_config = config.get(\"lora_config\").get(\"loftq_config\"),\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 131,
+ "referenced_widgets": [
+ "5d1fbd3c62d94df7befdefc451221414",
+ "8ad6abb48f38469f9d399eea8f5e5b70",
+ "6cea0da24cf54811a43168c606759bab",
+ "eb8c88f5c06c49fe9099371b3cf112ae",
+ "89a1354722e640758978befc06ed4a78",
+ "39d3b72ab6214bcf9b0bb6b6294e957c",
+ "696e82ec6a174974a90d5abc7c101ee7",
+ "dade882aca304a31b693a2c58807d825",
+ "02fc530028ea4d538b7f6b48463ae700",
+ "00eea4b0c6e44c62900ea8e7d919efe9",
+ "fe17bedb5ef04d8b9e064fa1e0d75185",
+ "bb1156b7d349440d9cc8a2f0328465a7",
+ "23a71f8847e647daba35e495706fc846",
+ "3f7afd4bd28842cbb73e62c155667030",
+ "a419499622cd4374937423a79677298f",
+ "64539b4212fe4d989976f56369bb746b",
+ "22ea45365d21439fb5069974bbe69711",
+ "bd087d0aa3214c5dbecc9b0bd4d976df",
+ "9a5fd3a68fd1445f92bea51a7fec3e6b",
+ "37803098ceed4528bb690ebee028c840",
+ "b93514308ae44afbb1a0511f5f9c6ddf",
+ "58b932a03b2c4aa4891d541f186244b9",
+ "3564e3cf0fe84281838d84525794e735",
+ "912164947c5847908424f3e60c5adb64",
+ "7517ce80636040e29665a9353afab183",
+ "e14b9d980a1a41fb9e81385cb0f73d3a",
+ "ada78aafba3f47ab8eb45cf3c83a6805",
+ "ff108c92fb5547869ee545cf9a094b07",
+ "2c5564fb033346afbe7692a24a52b302",
+ "bb078c8c1f6a48359dc654d91ece684d",
+ "9b9322336b564a409086955ebda07fc3",
+ "9bceb9eddb2147c1abbf3391c70e6784",
+ "8a195771bdc0462e8f9fbb60eb9141b1"
+ ]
+ },
+ "id": "ty1UIoRd9Hlv",
+ "outputId": "59bba8f0-2329-465f-dbe7-b5ee5adf3ee2"
+ },
+ "outputs": [],
+ "source": [
+ "from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer\n",
+ "tokenizer = AutoTokenizer.from_pretrained(config.get(\"model_config\").get(\"base_model\"))\n",
+ "\n",
+ "\n",
+ "tokenizer.add_eos_token = True\n",
+ "tokenizer.pad_token_id = 0\n",
+ "tokenizer.padding_side = \"left\"\n",
+ "\n",
+ "# Loading the training dataset\n",
+ "train_dataset = load_dataset(config_dataset.get(\"training_dataset\").get(\"name\"), split = config_dataset.get(\"training_dataset\").get(\"split\"))\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Vk-n3n_x9wh1"
+ },
+ "outputs": [],
+ "source": [
+ "# Select the first 100 rows of the dataset\n",
+ "test_dataset = train_dataset.select(range(100))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 174,
+ "referenced_widgets": [
+ "e3bd7f85ce194cd4b697c2eb82038658",
+ "734b6d3e3406403293c4bc955a643528",
+ "0005f2d9fe1e4cc98ea58b0c2868b433",
+ "be6162f66e594d3ebd8c53ebab3bbfa6",
+ "7e11cccce8be49008f8db3a0c3ea603d",
+ "dc3b2edc3f5d480a93b57b15b4444608",
+ "7967d420aff1414e9fe53eb04c928eb4",
+ "45c1d5b0df0e420a87f791dd4cf0e425",
+ "9ed49f1a099846a3a65cd6608bafb0e4",
+ "963c0aa5620b4ea8b5a903894646121c",
+ "31a203cdd2f54cda8a05214844888156"
+ ]
+ },
+ "id": "x8U2HpEh-OFi",
+ "outputId": "837b69f8-88f2-48a9-8e11-6178b4a5c269"
+ },
+ "outputs": [],
+ "source": [
+ "medical_prompt = \"\"\"You are an AI Medical Assistant Chatbot, trained to answer medical questions. Below is an instruction that describes a task, paired with a response context. Write a response that appropriately completes the request.\n",
+ "\n",
+ "### Instruction:\n",
+ "{}\n",
+ "\n",
+ "\n",
+ "### Response:\n",
+ "{}\"\"\"\n",
+ "\n",
+ "EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN\n",
+ "def formatting_prompts_func(examples):\n",
+ " instructions = examples[\"question\"]\n",
+ " outputs = examples[\"context\"]\n",
+ " texts = []\n",
+ " for instruction, output in zip(instructions, outputs):\n",
+ " # Must add EOS_TOKEN, otherwise your generation will go on forever!\n",
+ " text = medical_prompt.format(instruction, output) + EOS_TOKEN\n",
+ " texts.append(text)\n",
+ " return { \"text\" : texts, }\n",
+ "pass\n",
+ "\n",
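+ "# Apply the prompt template to every example, adding a formatted 'text' column\n",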
+ "test_dataset= test_dataset.map(formatting_prompts_func, batched = True,)\n",
+ "\n",
+ "\n",
+ "\n",
+ "test_dataset['text'][1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "DKAZ3zhx-TZA",
+ "outputId": "5fc0788a-3d11-4bcf-e502-717e7b3b5b2c"
+ },
+ "outputs": [],
+ "source": [
+ "test_dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 140,
+ "referenced_widgets": [
+ "72eca1e2871b458abd3383d9711215a2",
+ "058b2b9959b84b6f9f5d3862ef53d029",
+ "85d4879bd7d64766905db34cef052fed",
+ "44f189b81bbd48ca8cb146ead641d2b5",
+ "f89c5c949e984361bce7f97d86d2a2e5",
+ "7807f312425b4f4d9249aa1ac77d7461",
+ "d8e7ea9552a84b8284b31d77090b54af",
+ "0058ed544fed4272848a891a68b9adc0",
+ "33fb10908c23457aa4796626102fc8c5",
+ "e903140c8c794c48b231924d3975b7a6",
+ "7e74d789c82747e0b5066a00b9e36c1d"
+ ]
+ },
+ "id": "JkMVp2ZplGPA",
+ "outputId": "3c0777d0-e2a1-4a27-f035-615da4495e45"
+ },
+ "outputs": [],
+ "source": [
+ "# Setting up the trainer for the model\n",
+ "trainer_test = SFTTrainer(\n",
+ " model = model,\n",
+ " tokenizer = tokenizer,\n",
+ " train_dataset = test_dataset,\n",
+ " dataset_text_field = config_dataset.get(\"training_dataset\").get(\"input_field\"),\n",
+ " max_seq_length = config.get(\"model_config\").get(\"max_seq_length\"),\n",
+ " dataset_num_proc = 2,\n",
+ " packing = False,\n",
+ " args = TrainingArguments(\n",
+ " per_device_train_batch_size = config.get(\"training_config\").get(\"per_device_train_batch_size\"),\n",
+ " gradient_accumulation_steps = config.get(\"training_config\").get(\"gradient_accumulation_steps\"),\n",
+ " warmup_steps = config.get(\"training_config\").get(\"warmup_steps\"),\n",
+ " max_steps = config.get(\"training_config\").get(\"max_steps\"),\n",
+ " num_train_epochs= config.get(\"training_config\").get(\"num_train_epochs\"),\n",
+ " learning_rate = config.get(\"training_config\").get(\"learning_rate\"),\n",
+ " fp16 = config.get(\"training_config\").get(\"fp16\"),\n",
+ " bf16 = config.get(\"training_config\").get(\"bf16\"),\n",
+ " logging_steps = config.get(\"training_config\").get(\"logging_steps\"),\n",
+ " optim = config.get(\"training_config\").get(\"optim\"),\n",
+ " weight_decay = config.get(\"training_config\").get(\"weight_decay\"),\n",
+ " lr_scheduler_type = config.get(\"training_config\").get(\"lr_scheduler_type\"),\n",
+ " seed = 42,\n",
+ " output_dir = config.get(\"training_config\").get(\"output_dir\"),\n",
+ " ),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ZeRzS2N0kADu"
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "t00fCPO9zf8x"
+ },
+ "source": [
+ "## Method 1 optuna"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 725
+ },
+ "id": "H_MOQOYBj5jx",
+ "outputId": "699f77e3-8754-4087-dd78-565bca527d08"
+ },
+ "outputs": [],
+ "source": [
+ "from optuna import create_study, Trial\n",
+ "\n",
+ "# Define search space\n",
+ "search_space = {\n",
+ " \"learning_rate\": [1e-5, 5e-5, 1e-4, 2e-4],\n",
+ " \"per_device_train_batch_size\": [2, 4, 8],\n",
+ " \"lora_alpha\": [8, 16, 32],\n",
+ "}\n",
+ "\n",
+ "def objective(trial):\n",
+ " # Set hyperparameters based on trial values\n",
+ " config[\"training_config\"][\"learning_rate\"] = trial.suggest_float(\"learning_rate\", search_space[\"learning_rate\"][0], search_space[\"learning_rate\"][-1])\n",
+ " config[\"training_config\"][\"per_device_train_batch_size\"] = trial.suggest_int(\"per_device_train_batch_size\", search_space[\"per_device_train_batch_size\"][0], search_space[\"per_device_train_batch_size\"][-1])\n",
+ " config[\"lora_config\"][\"lora_alpha\"] = trial.suggest_int(\"lora_alpha\", search_space[\"lora_alpha\"][0], search_space[\"lora_alpha\"][-1])\n",
+ "\n",
+ " # Train the model with the current hyperparameters\n",
+ " try:\n",
+ " trainer_stats = trainer_test.train() # Assuming this trains the model\n",
+ " return trainer_stats.training_loss # The training loss is the metric to minimize\n",
+ " except Exception as e:\n",
+ " return float(\"inf\") # Assign a high value if training fails\n",
+ "\n",
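+ "# Run the study: each trial trains once and returns the value Optuna minimizes\n",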
+ "study = create_study(direction=\"minimize\")\n",
+ "study.optimize(objective, n_trials=2) # Adjust the number of trials\n",
+ "\n",
+ "# Access the best trial and its hyperparameters after optimization\n",
+ "best_trial = study.best_trial\n",
+ "best_params = best_trial.params\n",
+ "\n",
+ "print(\"Best Trial:\", best_trial.number)\n",
+ "print(\"Best Hyperparameters:\", best_params)\n",
+ "print(\"Best Training Loss:\", best_trial.value)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-84LcTvQ_xtH"
+ },
+ "source": [
+ "## Analyzing Hyperparameters:\n",
+ "\n",
+ "* **Batch Size**: Generally, increasing the batch size can improve training speed by utilizing hardware resources more efficiently. However, there is a limit beyond which performance degrades. You can tune the batch size within a reasonable range (e.g., 2, 4, 8, 16) to see its impact.\n",
+ "* **Learning Rate**: A higher learning rate can accelerate training initially, but too high a value can lead to unstable training and potentially slower convergence. Consider a range of learning rates (e.g., a log-uniform distribution between 1e-5 and 1e-3) for exploration.\n",
+ "* **Gradient Accumulation Steps**: This technique accumulates gradients over multiple batches before updating model weights. It can help reduce memory requirements but might slow down training per epoch. Experiment with different accumulation steps (e.g., 1, 2, 4) to find a balance.\n",
+ "* **Optimizer Choice**: Some optimizers, like Adam or SGD with momentum, can be faster than others depending on the model and dataset. Explore different optimizers and their hyperparameters (e.g., the momentum coefficient) to see if they lead to faster convergence.\n",
+ "\n",
+ "## Additional Considerations:\n",
+ "\n",
+ "* **Early Stopping**: Implement early stopping to automatically terminate training if the validation loss doesn't improve for a certain number of epochs. This can save training time if the model starts overfitting.\n",
+ "* **Warmup Steps**: A gradual increase in the learning rate during the initial training phase (warmup steps) can improve stability and potentially accelerate convergence compared to a fixed learning rate from the beginning.\n",
+ "* **Experimentation and Profiling**: The best hyperparameters for faster training depend on your specific model, dataset, and hardware. You'll need to experiment with different configurations using tools like Hyperopt or Optuna to find the optimal settings. Consider using profiling tools to identify bottlenecks in your training pipeline, so you can focus on the parts of the training process that are most time-consuming.\n",
+ "\n",
+ "By analyzing these hyperparameters and implementing techniques like early stopping and warmup steps (see the sketch below), you can potentially achieve faster fine-tuning while maintaining good model performance.\n",
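+ "\n",
+ "As a minimal sketch of the warmup and early-stopping ideas above (assuming an evaluation split named `eval_dataset` and the standard `transformers` `EarlyStoppingCallback`; the concrete values are illustrative, not taken from this notebook):\n",
+ "\n",
+ "```python\n",
+ "from transformers import EarlyStoppingCallback, TrainingArguments\n",
+ "\n",
+ "# Warmup: ramp the learning rate up over the first steps before the linear decay\n",
+ "args = TrainingArguments(\n",
+ " output_dir=\"outputs\",\n",
+ " warmup_steps=50, # illustrative value\n",
+ " learning_rate=2e-4,\n",
+ " evaluation_strategy=\"steps\", # early stopping needs periodic eval metrics\n",
+ " eval_steps=50,\n",
+ " save_steps=50,\n",
+ " load_best_model_at_end=True,\n",
+ " metric_for_best_model=\"eval_loss\",\n",
+ ")\n",
+ "\n",
+ "# Early stopping: halt training when eval_loss has not improved for 3 evaluations\n",
+ "# trainer = SFTTrainer(..., args=args, eval_dataset=eval_dataset)\n",
+ "# trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=3))\n",
+ "```"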
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "SdhZf88L_xdk"
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "jRbfR2n1wZrt"
+ },
+ "outputs": [],
+ "source": [
+ "## Method 1b Speed"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 636
+ },
+ "id": "uf-zwbRPteGH",
+ "outputId": "33a23501-2a5b-4c7d-faf8-e97e4c653811"
+ },
+ "outputs": [],
+ "source": [
+ "from optuna import create_study, Trial\n",
+ "import time # Assuming you can use time.time() to measure training time\n",
+ "\n",
+ "# Define search space with additional hyperparameter\n",
+ "search_space = {\n",
+ " \"learning_rate\": [1e-5, 5e-5, 1e-4, 2e-4],\n",
+ " \"per_device_train_batch_size\": [2, 4, 8],\n",
+ " \"lora_alpha\": [8, 16, 32],\n",
+ " \"gradient_accumulation_steps\": [1, 2, 4, 8], # Added gradient accumulation steps\n",
+ "}\n",
+ "\n",
+ "def objective(trial):\n",
+ " # Set hyperparameters based on trial values\n",
+ " config[\"training_config\"][\"learning_rate\"] = trial.suggest_float(\"learning_rate\", search_space[\"learning_rate\"][0], search_space[\"learning_rate\"][-1])\n",
+ " config[\"training_config\"][\"per_device_train_batch_size\"] = trial.suggest_int(\"per_device_train_batch_size\", search_space[\"per_device_train_batch_size\"][0], search_space[\"per_device_train_batch_size\"][-1])\n",
+ " config[\"training_config\"][\"gradient_accumulation_steps\"] = trial.suggest_int(\"gradient_accumulation_steps\", search_space[\"gradient_accumulation_steps\"][0], search_space[\"gradient_accumulation_steps\"][-1])\n",
+ " config[\"lora_config\"][\"lora_alpha\"] = trial.suggest_int(\"lora_alpha\", search_space[\"lora_alpha\"][0], search_space[\"lora_alpha\"][-1])\n",
+ "\n",
+ " # Train the model with the current hyperparameters\n",
+ " start_time = time.time()\n",
+ " try:\n",
+ " trainer_stats = trainer_test.train()\n",
+ " training_time = time.time() - start_time\n",
+ " return training_time # Minimize training time\n",
+ " except Exception as e:\n",
+ " return float(\"inf\") # Assign a high value if training fails\n",
+ "\n",
+ "study = create_study(direction=\"minimize\")\n",
+ "study.optimize(objective, n_trials=2) # Adjust the number of trials\n",
+ "\n",
+ "# Access the best trial and its hyperparameters after optimization\n",
+ "best_trial = study.best_trial\n",
+ "best_params = best_trial.params\n",
+ "\n",
+ "print(\"Best Trial:\", best_trial.number)\n",
+ "print(\"Best Hyperparameters (Likely Fastest):\", best_params)\n",
+ "print(\"Best Training Time:\", best_trial.value, \"seconds\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "1Vz6NAbxxxlM",
+ "outputId": "ec166d41-fa4d-40a3-df3e-f44890010f07"
+ },
+ "outputs": [],
+ "source": [
+ "import hyperopt\n",
+ "from hyperopt import fmin, tpe, hp, Trials\n",
+ "# Define the search space for hyperparameters\n",
+ "space = {\n",
+ " 'learning_rate': hp.loguniform('learning_rate', -5, -1), # Learning rate in log scale\n",
+ " 'lora_alpha': hp.quniform('lora_alpha', 1, 32, 1), # LoRA alpha with quantized steps\n",
+ " 'lora_dropout': hp.uniform('lora_dropout', 0, 0.5), # LoRA dropout rate\n",
+ " # Uncomment these if you want to tune them\n",
+ " # 'per_device_train_batch_size': hp.quniform('per_device_train_batch_size', 2, 16, 1),\n",
+ " # 'gradient_accumulation_steps': hp.quniform('gradient_accumulation_steps', 1, 8, 1),\n",
+ " # 'warmup_steps': hp.quniform('warmup_steps', 0, 1000, 1),\n",
+ " # 'num_train_epochs': hp.quniform('num_train_epochs', 1, 5, 1),\n",
+ "}\n",
+ "def objective(params):\n",
+ " # Set hyperparameters in the config dictionary (assuming it's defined elsewhere)\n",
+ " config['training_config']['learning_rate'] = params['learning_rate']\n",
+ " config['lora_config']['lora_alpha'] = params['lora_alpha']\n",
+ " config['lora_config']['lora_dropout'] = params['lora_dropout']\n",
+ " # ... Set other hyperparameters from params dictionary ...\n",
+ " #config['training_config']['per_device_train_batch_size'] = params['per_device_train_batch_size']\n",
+ " #config['training_config']['gradient_accumulation_steps'] = params['gradient_accumulation_steps']\n",
+ " #config['training_config']['warmup_steps'] = params['warmup_steps']\n",
+ " #config['training_config']['num_train_epochs'] = params['num_train_epochs']\n",
+ "\n",
+ " # Load the model and tokenizer (assuming these are defined elsewhere)\n",
+ " try:\n",
+ " model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ " model_name=config.get(\"model_config\").get(\"base_model\"),\n",
+ " max_seq_length=config.get(\"model_config\").get(\"max_seq_length\"),\n",
+ " dtype=config.get(\"model_config\").get(\"dtype\"),\n",
+ " load_in_4bit=config.get(\"model_config\").get(\"load_in_4bit\"),\n",
+ " )\n",
+ " except Exception as e:\n",
+ " print(f\"Error loading model and tokenizer: {e}\")\n",
+ " return float(\"inf\") # Return high value for errors\n",
+ "\n",
+ " # Setup LoRA for the model (assuming FastLanguageModel supports LoRA)\n",
+ " try:\n",
+ " model = FastLanguageModel.get_peft_model(\n",
+ " model,\n",
+ " r=config.get(\"lora_config\").get(\"r\"),\n",
+ " target_modules=config.get(\"lora_config\").get(\"target_modules\"),\n",
+ " lora_alpha=params['lora_alpha'],\n",
+ " lora_dropout=params['lora_dropout'],\n",
+ " bias=config.get(\"lora_config\").get(\"bias\"),\n",
+ " use_gradient_checkpointing=config.get(\"lora_config\").get(\"use_gradient_checkpointing\"),\n",
+ " random_state=42,\n",
+ " use_rslora=config.get(\"lora_config\").get(\"use_rslora\"),\n",
+ " use_dora=config.get(\"lora_config\").get(\"use_dora\"),\n",
+ " loftq_config=config.get(\"lora_config\").get(\"loftq_config\")\n",
+ " )\n",
+ " except Exception as e:\n",
+ " print(f\"Error setting up LoRA: {e}\")\n",
+ " return float(\"inf\") # Return high value for errors\n",
+ " # Train the model on the test dataset (assuming SFTTrainer and training arguments are defined)\n",
+ " try:\n",
+ " trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " tokenizer=tokenizer,\n",
+ " train_dataset=test_dataset,\n",
+ " dataset_text_field=config_dataset.get(\"training_dataset\").get(\"input_field\"),\n",
+ " max_seq_length=config.get(\"model_config\").get(\"max_seq_length\"),\n",
+ " dataset_num_proc=2,\n",
+ " packing=False,\n",
+ " args=TrainingArguments(\n",
+ " per_device_train_batch_size=int(params['per_device_train_batch_size']),\n",
+ " gradient_accumulation_steps=params['gradient_accumulation_steps'],\n",
+ " warmup_steps=params['warmup_steps'],\n",
+ " max_steps=config.get(\"training_config\").get(\"max_steps\"),\n",
+ " num_train_epochs=params['num_train_epochs'],\n",
+ " learning_rate=params['learning_rate'],\n",
+ " fp16=config.get(\"training_config\").get(\"fp16\"),\n",
+ " bf16=config.get(\"training_config\").get(\"bf16\"),\n",
+ " logging_steps=config.get(\"training_config\").get(\"logging_steps\"),\n",
+ " optim=config.get(\"training_config\").get(\"optim\"),\n",
+ " weight_decay=config.get(\"training_config\").get(\"weight_decay\"),\n",
+ " lr_scheduler_type=config.get(\"training_config\").get(\"lr_scheduler_type\"),\n",
+ " seed=42,\n",
+ " output_dir=config.get(\"training_config\").get(\"output_dir\")\n",
+ " )\n",
+ " )\n",
+ " trainer_stats = trainer.train()\n",
+ " return trainer_stats.loss # Assuming loss is the metric to minimize\n",
+ " except Exception as e:\n",
+ " print(f\"Error during training: {e}\")\n",
+ " return float(\"inf\") # Return high value for failed trials\n",
+ "\n",
+ "# Create a Trials object to track hyperparameter evaluations\n",
+ "trials = Trials()\n",
+ "\n",
+ "# Run hyperparameter optimization using TPE algorithm\n",
+ "best = fmin(objective, space, algo=tpe.suggest, trials=trials, max_evals=2)\n",
+ "\n",
+ "# Print the best hyperparameters found during optimization\n",
+ "print(\"Best Hyperparameters:\", best)\n"
+ ]
+ },
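+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Hedged sketch: `fmin` returns raw assignments, and hp.quniform values come back as floats.\n",
+ "# `space_eval` (a standard Hyperopt helper) maps the result back onto the search space; the int cast\n",
+ "# below is an assumption about how you would reuse the value in TrainingArguments.\n",
+ "from hyperopt import space_eval\n",
+ "\n",
+ "best_config = space_eval(space, best)\n",
+ "best_config['lora_alpha'] = int(best_config['lora_alpha'])\n",
+ "print(\"Best hyperparameters (typed):\", best_config)"
+ ]
+ },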
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "RnjkQJ852_c2",
+ "outputId": "6dd97901-3c17-4d46-d04f-39a7a335a760"
+ },
+ "outputs": [],
+ "source": [
+ "import hyperopt\n",
+ "from hyperopt import hp\n",
+ "from hyperopt import Trials\n",
+ "from hyperopt import fmin, tpe, Trials\n",
+ "\n",
+ "# Define the search space for hyperparameters with uncommented additions\n",
+ "space = {\n",
+ " 'learning_rate': hp.loguniform('learning_rate', -5, -1), # Learning rate in log scale\n",
+ " 'lora_alpha': hp.quniform('lora_alpha', 1, 32, 1), # LoRA alpha with quantized steps\n",
+ " 'lora_dropout': hp.uniform('lora_dropout', 0, 0.5), # LoRA dropout rate\n",
+ " 'per_device_train_batch_size': hp.quniform('per_device_train_batch_size', 2, 16, 1), # Added for exploration\n",
+ " 'gradient_accumulation_steps': hp.quniform('gradient_accumulation_steps', 1, 8, 1), # Added for exploration\n",
+ " # Uncomment these if you want to tune other hyperparameters\n",
+ " # 'warmup_steps': hp.quniform('warmup_steps', 0, 1000, 1),\n",
+ " # 'num_train_epochs': hp.quniform('num_train_epochs', 1, 5, 1),\n",
+ "}\n",
+ "\n",
+ "\n",
+ "def objective(params):\n",
+ " # Set hyperparameters in the config dictionary (assuming it's defined elsewhere)\n",
+ " config['training_config']['learning_rate'] = params['learning_rate']\n",
+ " config['lora_config']['lora_alpha'] = params['lora_alpha']\n",
+ " config['lora_config']['lora_dropout'] = params['lora_dropout']\n",
+ " config['training_config']['per_device_train_batch_size'] = params['per_device_train_batch_size']\n",
+ " config['training_config']['gradient_accumulation_steps'] = params['gradient_accumulation_steps']\n",
+ " # ... Set other hyperparameters from params dictionary ...\n",
+ "\n",
+ " # Load the model and tokenizer (assuming these are defined elsewhere)\n",
+ " try:\n",
+ " model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ " model_name=config.get(\"model_config\").get(\"base_model\"),\n",
+ " max_seq_length=config.get(\"model_config\").get(\"max_seq_length\"),\n",
+ " dtype=config.get(\"model_config\").get(\"dtype\"),\n",
+ " load_in_4bit=config.get(\"model_config\").get(\"load_in_4bit\"),\n",
+ " )\n",
+ " except Exception as e:\n",
+ " print(f\"Error loading model and tokenizer: {e}\")\n",
+ " return float(\"inf\") # Return high value for errors\n",
+ "\n",
+ " # Setup LoRA for the model (assuming FastLanguageModel supports LoRA)\n",
+ " try:\n",
+ " model = FastLanguageModel.get_peft_model(\n",
+ " model,\n",
+ " r=config.get(\"lora_config\").get(\"r\"),\n",
+ " target_modules=config.get(\"lora_config\").get(\"target_modules\"),\n",
+ " lora_alpha=params['lora_alpha'],\n",
+ " lora_dropout=params['lora_dropout'],\n",
+ " bias=config.get(\"lora_config\").get(\"bias\"),\n",
+ " use_gradient_checkpointing=config.get(\"lora_config\").get(\"use_gradient_checkpointing\"),\n",
+ " random_state=42,\n",
+ " use_rslora=config.get(\"lora_config\").get(\"use_rslora\"),\n",
+ " use_dora=config.get(\"lora_config\").get(\"use_dora\"),\n",
+ " loftq_config=config.get(\"lora_config\").get(\"loftq_config\")\n",
+ " )\n",
+ " except Exception as e:\n",
+ " print(f\"Error setting up LoRA: {e}\")\n",
+ " return float(\"inf\") # Return high value for errors\n",
+ "\n",
+ " # Train the model on the test dataset (assuming SFTTrainer and training arguments are defined)\n",
+ " try:\n",
+ " trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " tokenizer=tokenizer,\n",
+ " train_dataset=test_dataset,\n",
+ " dataset_text_field=config_dataset.get(\"training_dataset\").get(\"input_field\"),\n",
+ " max_seq_length=config.get(\"model_config\").get(\"max_seq_length\"),\n",
+ " dataset_num_proc=2,\n",
+ " packing=False,\n",
+ " args=TrainingArguments(\n",
+ " per_device_train_batch_size=int(params['per_device_train_batch_size']),\n",
+ " gradient_accumulation_steps=params['gradient_accumulation_steps'],\n",
+ " warmup_steps=params['warmup_steps'],\n",
+ " max_steps=config.get(\"training_config\").get(\"max_steps\"),\n",
+ " num_train_epochs=params['num_train_epochs'],\n",
+ " learning_rate=params['learning_rate'],\n",
+ " fp16=config.get(\"training_config\").get(\"fp16\"),\n",
+ " bf16=config.get(\"training_config\").get(\"bf16\"),\n",
+ " logging_steps=config.get(\"training_config\").get(\"logging_steps\"),\n",
+ " optim=config.get(\"training_config\").get(\"optim\"),\n",
+ " weight_decay=config.get(\"training_config\").get(\"weight_decay\"),\n",
+ " lr_scheduler_type=config.get(\"training_config\").get(\"lr_scheduler_type\"),\n",
+ " seed=42,\n",
+ " output_dir=config.get(\"training_config\").get(\"output_dir\")\n",
+ " )\n",
+ " )\n",
+ " trainer_stats = trainer.train()\n",
+ " return trainer_stats.loss # Assuming loss is the metric to minimize\n",
+ " except Exception as e:\n",
+ " print(f\"Error during training: {e}\")\n",
+ " return float(\"inf\") # Return high value for failed trials\n",
+ "\n",
+ "# Create a Trials object to track hyperparameter evaluations\n",
+ "trials = Trials()\n",
+ "\n",
+ "# Run hyperparameter optimization using TPE algorithm\n",
+ "best = fmin(objective, space, algo=tpe.suggest, trials=trials, max_evals=2)\n",
+ "\n",
+ "# Print the best hyperparameters found during optimization\n",
+ "print(\"Best Hyperparameters:\", best)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ID7nFKsV5urO"
+ },
+ "outputs": [],
+ "source": [
+ "## Method"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "xp6S8LGg4lUG",
+ "outputId": "537667cd-7711-4d10-e2da-d2337b80c43a"
+ },
+ "outputs": [],
+ "source": [
+ "import hyperopt\n",
+ "from hyperopt import hp\n",
+ "from hyperopt import Trials\n",
+ "from hyperopt import fmin, tpe, Trials\n",
+ "import time # Import time for measuring training duration\n",
+ "\n",
+ "# Define the search space for hyperparameters with uncommented additions\n",
+ "space = {\n",
+ " 'learning_rate': hp.loguniform('learning_rate', -5, -1), # Learning rate in log scale\n",
+ " 'lora_alpha': hp.quniform('lora_alpha', 1, 32, 1), # LoRA alpha with quantized steps\n",
+ " 'lora_dropout': hp.uniform('lora_dropout', 0, 0.5), # LoRA dropout rate\n",
+ " 'per_device_train_batch_size': hp.quniform('per_device_train_batch_size', 2, 16, 1), # Added for exploration\n",
+ " 'gradient_accumulation_steps': hp.quniform('gradient_accumulation_steps', 1, 8, 1), # Added for exploration\n",
+ " # Uncomment these if you want to tune other hyperparameters\n",
+ " # 'warmup_steps': hp.quniform('warmup_steps', 0, 1000, 1),\n",
+ " # 'num_train_epochs': hp.quniform('num_train_epochs', 1, 5, 1),\n",
+ "}\n",
+ "\n",
+ "\n",
+ "def objective(params):\n",
+ " # Set hyperparameters in the config dictionary (assuming it's defined elsewhere)\n",
+ " config['training_config']['learning_rate'] = params['learning_rate']\n",
+ " config['lora_config']['lora_alpha'] = params['lora_alpha']\n",
+ " config['lora_config']['lora_dropout'] = params['lora_dropout']\n",
+ " config['training_config']['per_device_train_batch_size'] = params['per_device_train_batch_size']\n",
+ " config['training_config']['gradient_accumulation_steps'] = params['gradient_accumulation_steps']\n",
+ " # ... Set other hyperparameters from params dictionary ...\n",
+ "\n",
+ " # Load the model and tokenizer (assuming these are defined elsewhere)\n",
+ " try:\n",
+ " model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ " model_name=config.get(\"model_config\").get(\"base_model\"),\n",
+ " max_seq_length=config.get(\"model_config\").get(\"max_seq_length\"),\n",
+ " dtype=config.get(\"model_config\").get(\"dtype\"),\n",
+ " load_in_4bit=config.get(\"model_config\").get(\"load_in_4bit\"),\n",
+ " )\n",
+ " except Exception as e:\n",
+ " print(f\"Error loading model and tokenizer: {e}\")\n",
+ " return float(\"inf\") # Return high value for errors\n",
+ "\n",
+ " # Setup LoRA for the model (assuming FastLanguageModel supports LoRA)\n",
+ " try:\n",
+ " model = FastLanguageModel.get_peft_model(\n",
+ " model,\n",
+ " r=config.get(\"lora_config\").get(\"r\"),\n",
+ " target_modules=config.get(\"lora_config\").get(\"target_modules\"),\n",
+ " lora_alpha=params['lora_alpha'],\n",
+ " lora_dropout=params['lora_dropout'],\n",
+ " bias=config.get(\"lora_config\").get(\"bias\"),\n",
+ " use_gradient_checkpointing=config.get(\"lora_config\").get(\"use_gradient_checkpointing\"),\n",
+ " random_state=42,\n",
+ " use_rslora=config.get(\"lora_config\").get(\"use_rslora\"),\n",
+ " use_dora=config.get(\"lora_config\").get(\"use_dora\"),\n",
+ " loftq_config=config.get(\"lora_config\").get(\"loftq_config\")\n",
+ " )\n",
+ " except Exception as e:\n",
+ " print(f\"Error setting up LoRA: {e}\")\n",
+ " return float(\"inf\") # Return high value for errors\n",
+ "\n",
+ " # Train the model on the test dataset (assuming SFTTrainer and training arguments are defined)\n",
+ " try:\n",
+ " start_time = time.time() # Measure training start time\n",
+ " trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " tokenizer=tokenizer,\n",
+ " train_dataset=test_dataset,\n",
+ " dataset_text_field=config_dataset.get(\"training_dataset\").get(\"input_field\"),\n",
+ " max_seq_length=config.get(\"model_config\").get(\"max_seq_length\"),\n",
+ " dataset_num_proc=2,\n",
+ " packing=False,\n",
+ " args=TrainingArguments(\n",
+ " per_device_train_batch_size=int(params['per_device_train_batch_size']),\n",
+ " gradient_accumulation_steps=params['gradient_accumulation_steps'],\n",
+ " warmup_steps=params['warmup_steps'],\n",
+ " max_steps=config.get(\"training_config\").get(\"max_steps\"),\n",
+ " num_train_epochs=params['num_train_epochs'],\n",
+ " learning_rate=params['learning_rate'],\n",
+ " fp16=config.get(\"training_config\").get(\"fp16\"),\n",
+ " bf16=config.get(\"training_config\").get(\"bf16\"),\n",
+ " logging_steps=config.get(\"training_config\").get(\"logging_steps\"),\n",
+ " optim=config.get(\"training_config\").get(\"optim\"),\n",
+ " weight_decay=config.get(\"training_config\").get(\"weight_decay\"),\n",
+ " lr_scheduler_type=config.get(\"training_config\").get(\"lr_scheduler_type\"),\n",
+ " seed=42,\n",
+ " output_dir=config.get(\"training_config\").get(\"output_dir\")\n",
+ " )\n",
+ " )\n",
+ " trainer_stats = trainer.train()\n",
+ " end_time = time.time() # Measure training end time\n",
+ " training_time = end_time - start_time # Calculate training time\n",
+ "\n",
+ " return training_time # Return training time for minimization\n",
+ " except Exception as e:\n",
+ " print(f\"Error during training: {e}\")\n",
+ " return float(\"inf\") # Return high value for failed trials\n",
+ "\n",
+ "# Create a Trials object to track hyperparameter evaluations\n",
+ "trials = Trials()\n",
+ "\n",
+ "# Run hyperparameter optimization using TPE algorithm\n",
+ "best = fmin(objective, space, algo=tpe.suggest, trials=trials, max_evals=2)\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "1LIpTKWI5NTV",
+ "outputId": "930484d7-c820-4cd3-80ed-f74ae6761346"
+ },
+ "outputs": [],
+ "source": [
+ "# Print the best hyperparameters found during optimization\n",
+ "print(\"Best Hyperparameters:\", best)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Y70d0UUS5Izr"
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vKqIDJIGYV11"
+ },
+ "source": [
+ "# Hyperparameter search\n",
+ "**Step 1: Define the Hyperparameter Search Space**\n",
+ "We need to define the search space for the hyperparameters we want to tune. For example, let's say we want to tune the following hyperparameters:\n",
+ "\n",
+ "* `learning_rate`\n",
+ "* `per_device_train_batch_size`\n",
+ "* `gradient_accumulation_steps`\n",
+ "* `warmup_steps`\n",
+ "* `num_train_epochs`\n",
+ "* `lora_alpha`\n",
+ "* `lora_dropout`\n",
+ "\n",
+ "We can define the search space as follows:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ETCIc-5JYvEq"
+ },
+ "outputs": [],
+ "source": [
+ "import hyperopt\n",
+ "from hyperopt import hp\n",
+ "from hyperopt import Trials\n",
+ "from hyperopt import fmin, tpe, Trials\n",
+ "# Define the search space for hyperparameters\n",
+ "space = {\n",
+ " 'learning_rate': hp.loguniform('learning_rate', -5, -1), # Learning rate in log scale\n",
+ " 'lora_alpha': hp.quniform('lora_alpha', 1, 32, 1), # LoRA alpha with quantized steps\n",
+ " 'lora_dropout': hp.uniform('lora_dropout', 0, 0.5), # LoRA dropout rate\n",
+ " # Uncomment these if you want to tune them\n",
+ " # 'per_device_train_batch_size': hp.quniform('per_device_train_batch_size', 2, 16, 1),\n",
+ " # 'gradient_accumulation_steps': hp.quniform('gradient_accumulation_steps', 1, 8, 1),\n",
+ " # 'warmup_steps': hp.quniform('warmup_steps', 0, 1000, 1),\n",
+ " # 'num_train_epochs': hp.quniform('num_train_epochs', 1, 5, 1),\n",
+ "}"
+ ]
+ },
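+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional sanity check (a sketch, not part of the original flow): draw one random point from the\n",
+ "# search space to see the kind of values Hyperopt will propose. `sample` is Hyperopt's own helper.\n",
+ "from hyperopt.pyll.stochastic import sample\n",
+ "\n",
+ "print(sample(space))"
+ ]
+ },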
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "t1i7r2glY2Df"
+ },
+ "source": [
+ "**Step 2. Define the Objective Function**\n",
+ "\n",
+ "The objective function is a function that takes in the hyperparameters, sets them in the `config` dictionary, trains the model, and returns the loss or metric to minimize. We need to modify the previous fine-tuning code to define the objective function."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "mUTbsbJQb08e"
+ },
+ "outputs": [],
+ "source": [
+ "def objective(params):\n",
+ " # Set hyperparameters in the config dictionary (assuming it's defined elsewhere)\n",
+ " config['training_config']['learning_rate'] = params['learning_rate']\n",
+ " config['lora_config']['lora_alpha'] = params['lora_alpha']\n",
+ " config['lora_config']['lora_dropout'] = params['lora_dropout']\n",
+ " # ... Set other hyperparameters from params dictionary ...\n",
+ " #config['training_config']['per_device_train_batch_size'] = params['per_device_train_batch_size']\n",
+ " #config['training_config']['gradient_accumulation_steps'] = params['gradient_accumulation_steps']\n",
+ " #config['training_config']['warmup_steps'] = params['warmup_steps']\n",
+ " #config['training_config']['num_train_epochs'] = params['num_train_epochs']\n",
+ "\n",
+ " # Load the model and tokenizer (assuming these are defined elsewhere)\n",
+ " try:\n",
+ " model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ " model_name=config.get(\"model_config\").get(\"base_model\"),\n",
+ " max_seq_length=config.get(\"model_config\").get(\"max_seq_length\"),\n",
+ " dtype=config.get(\"model_config\").get(\"dtype\"),\n",
+ " load_in_4bit=config.get(\"model_config\").get(\"load_in_4bit\"),\n",
+ " )\n",
+ " except Exception as e:\n",
+ " print(f\"Error loading model and tokenizer: {e}\")\n",
+ " return float(\"inf\") # Return high value for errors\n",
+ "\n",
+ " # Setup LoRA for the model (assuming FastLanguageModel supports LoRA)\n",
+ " try:\n",
+ " model = FastLanguageModel.get_peft_model(\n",
+ " model,\n",
+ " r=config.get(\"lora_config\").get(\"r\"),\n",
+ " target_modules=config.get(\"lora_config\").get(\"target_modules\"),\n",
+ " lora_alpha=params['lora_alpha'],\n",
+ " lora_dropout=params['lora_dropout'],\n",
+ " bias=config.get(\"lora_config\").get(\"bias\"),\n",
+ " use_gradient_checkpointing=config.get(\"lora_config\").get(\"use_gradient_checkpointing\"),\n",
+ " random_state=42,\n",
+ " use_rslora=config.get(\"lora_config\").get(\"use_rslora\"),\n",
+ " use_dora=config.get(\"lora_config\").get(\"use_dora\"),\n",
+ " loftq_config=config.get(\"lora_config\").get(\"loftq_config\")\n",
+ " )\n",
+ " except Exception as e:\n",
+ " print(f\"Error setting up LoRA: {e}\")\n",
+ " return float(\"inf\") # Return high value for errors\n",
+ " # Train the model on the test dataset (assuming SFTTrainer and training arguments are defined)\n",
+ " try:\n",
+ " trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " tokenizer=tokenizer,\n",
+ " train_dataset=test_dataset,\n",
+ " dataset_text_field=config_dataset.get(\"training_dataset\").get(\"input_field\"),\n",
+ " max_seq_length=config.get(\"model_config\").get(\"max_seq_length\"),\n",
+ " dataset_num_proc=2,\n",
+ " packing=False,\n",
+ " args=TrainingArguments(\n",
+ " per_device_train_batch_size=int(params['per_device_train_batch_size']),\n",
+ " gradient_accumulation_steps=params['gradient_accumulation_steps'],\n",
+ " warmup_steps=params['warmup_steps'],\n",
+ " max_steps=config.get(\"training_config\").get(\"max_steps\"),\n",
+ " num_train_epochs=params['num_train_epochs'],\n",
+ " learning_rate=params['learning_rate'],\n",
+ " fp16=config.get(\"training_config\").get(\"fp16\"),\n",
+ " bf16=config.get(\"training_config\").get(\"bf16\"),\n",
+ " logging_steps=config.get(\"training_config\").get(\"logging_steps\"),\n",
+ " optim=config.get(\"training_config\").get(\"optim\"),\n",
+ " weight_decay=config.get(\"training_config\").get(\"weight_decay\"),\n",
+ " lr_scheduler_type=config.get(\"training_config\").get(\"lr_scheduler_type\"),\n",
+ " seed=42,\n",
+ " output_dir=config.get(\"training_config\").get(\"output_dir\")\n",
+ " )\n",
+ " )\n",
+ " trainer_stats = trainer.train()\n",
+ " return trainer_stats.loss # Assuming loss is the metric to minimize\n",
+ " except Exception as e:\n",
+ " print(f\"Error during training: {e}\")\n",
+ " return float(\"inf\") # Return high value for failed trials\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "z7od3txJaZbm"
+ },
+ "source": [
+ "**Step 3: Perform Hyperparameter Search**\n",
+ "\n",
+ "Now that we have defined the objective function, we can perform the hyperparameter search using Hyperopt's `fmin` function. We need to specify the objective function, the search space, and the maximum number of evaluations."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "vLTpYBVzbpmP",
+ "outputId": "ce2d2b57-2e40-4ae8-ec20-880b78be3a56"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Create a Trials object to track hyperparameter evaluations\n",
+ "trials = Trials()\n",
+ "# Run hyperparameter optimization using TPE algorithm\n",
+ "best = fmin(objective, space, algo=tpe.suggest, trials=trials, max_evals=2)\n",
+ "# Print the best hyperparameters found during optimization\n",
+ "print(\"Best Hyperparameters:\", best)"
+ ]
+ },
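+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Hedged sketch: the `Trials` object created above records every evaluation, so after `fmin` finishes\n",
+ "# you can inspect per-trial losses and the best trial. Both attributes are standard Hyperopt.\n",
+ "print(\"Per-trial losses:\", trials.losses())\n",
+ "print(\"Best trial result:\", trials.best_trial[\"result\"])"
+ ]
+ },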
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "FYO9wV8IoXpy"
+ },
+ "outputs": [],
+ "source": [
+ "from huggingface_hub import login, logout"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "z03TocnqoXpy",
+ "outputId": "c598ea52-e319-41ed-cc51-935f61201178"
+ },
+ "outputs": [],
+ "source": [
+ "#login(token) # non-blocking login"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "dwDh_WpSoXpy",
+ "outputId": "28fbd65b-61fd-433c-df4f-819a06d4ba05"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import gc\n",
+ "def reset_gpu_memory():\n",
+ " torch.cuda.empty_cache()\n",
+ " gc.collect()\n",
+ " print(\"GPU memory cleared!\")\n",
+ "# Example usage:\n",
+ "reset_gpu_memory()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "yhpc3w89A3A_"
+ },
+ "source": [
+ "Best Hyperparameters: {'learning_rate': 0.03347123299210303, 'lora_alpha': 19.0, 'lora_dropout': 0.4819141472093197}\n",
+ "\n",
+ "Best Hyperparameters: {'gradient_accumulation_steps': 8.0, 'learning_rate': 0.23274337759179295, 'lora_alpha': 8.0, 'lora_dropout': 0.0491660925212421, 'per_device_train_batch_size': 13.0}\n",
+ "\n",
+ "Best Hyperparameters: {'gradient_accumulation_steps': 4.0, 'learning_rate': 0.186066529001672, 'lora_alpha': 32.0, 'lora_dropout': 0.24368804023352264, 'per_device_train_batch_size': 10.0}\n",
+ "\n",
+ "Best Hyperparameters: {'learning_rate': 0.011846192509972951, 'lora_alpha': 8.0, 'lora_dropout': 0.2087248476879589}\n",
+ "\n",
+ "\n",
+ "\n",
+ "Best Hyperparameters (Likely Fastest): {'learning_rate': 1.881999040862022e-05, 'per_device_train_batch_size': 2, 'gradient_accumulation_steps': 2, 'lora_alpha': 29}\n",
+ "Best Training Time: 48.178661584854126 seconds\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Oh4LwgiZ6d3L"
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "qQQN-uSUzB8h"
+ },
+ "outputs": [],
+ "source": [
+ "# Defining the configuration for the base model, LoRA and training\n",
+ "config = {\n",
+ " \"hugging_face_username\":\"ruslanmv\",\n",
+ " \"model_config\": {\n",
+ " \"base_model\":\"meta-llama/Meta-Llama-3-8B-Instruct\", # The base model\n",
+ " \"finetuned_model\":\"ruslanmv/Medical-Mind-Llama-3-8b\", # The finetuned model\n",
+ " \"max_seq_length\": 2048, # The maximum sequence length\n",
+ " # \"dtype\":torch.float16, # The data type\n",
+ " # \"dtype\": torch.float32, # Use float32 instead of half CUDA capability < 8\n",
+ " \"dtype\" : None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+\n",
+ "\n",
+ " \"load_in_4bit\": True, # Load the model in 4-bit\n",
+ " },\n",
+ " \"lora_config\": {\n",
+ " \"r\": 16, # The number of LoRA layers 8, 16, 32, 64\n",
+ " \"target_modules\": [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
+ " \"gate_proj\", \"up_proj\", \"down_proj\"], # The target modules\n",
+ " #\"lora_alpha\":16, # The alpha value for LoRA\n",
+ " \"lora_alpha\":29, # The alpha value for LoRA by search grid\n",
+ " \"lora_dropout\":0, # The dropout value for LoRA\n",
+ " \"bias\":\"none\", # The bias for LoRA\n",
+ " \"use_gradient_checkpointing\":True, # Use gradient checkpointing\n",
+ " \"use_rslora\":False, # Use RSLora\n",
+ " \"use_dora\":False, # Use DoRa\n",
+ " \"loftq_config\":None # The LoFTQ configuration\n",
+ " },\n",
+ "\n",
+ " \"training_config\": {\n",
+ " #\"per_device_train_batch_size\": 2, # The batch size\n",
+ " \"per_device_train_batch_size\": 2, # The batch size by search grid\n",
+ " #\"gradient_accumulation_steps\": 4, # The gradient accumulation steps\n",
+ " \"gradient_accumulation_steps\": 2, # The gradient accumulation steps by search grid\n",
+ " \"warmup_steps\": 5, # The warmup steps\n",
+ " \"max_steps\":0, # The maximum steps (0 if the epochs are defined)\n",
+ " \"num_train_epochs\": 1, # The number of training epochs(0 if the maximum steps are defined)\n",
+ " #\"learning_rate\": 2e-4, # The learning rate\n",
+ " \"learning_rate\": 1.88e-05, # The learning rate by search grid\n",
+ " \"fp16\": not torch.cuda.is_bf16_supported(), # The fp16\n",
+ " \"bf16\": torch.cuda.is_bf16_supported(), # The bf16\n",
+ " \"logging_steps\": 1, # The logging steps\n",
+ " \"optim\" :\"adamw_8bit\", # The optimizer\n",
+ " \"weight_decay\" : 0.01, # The weight decay\n",
+ " \"lr_scheduler_type\": \"linear\", # The learning rate scheduler\n",
+ " \"seed\" : 42, # The seed\n",
+ " \"output_dir\" : \"outputs\", # The output directory\n",
+ " }\n",
+ "}"
+ ]
+ },
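+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Quick sanity check (a sketch; uses only the `config` defined above): the optimizer sees an effective\n",
+ "# batch of per-device batch size x gradient accumulation steps (x number of GPUs, if more than one).\n",
+ "tc = config[\"training_config\"]\n",
+ "effective_batch = tc[\"per_device_train_batch_size\"] * tc[\"gradient_accumulation_steps\"]\n",
+ "print(f\"Effective batch size per optimizer step (single GPU): {effective_batch}\")"
+ ]
+ },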
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "bX5eLb3Ss-39"
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "LEMLkJAeoXpy",
+ "outputId": "6814c878-7716-4201-c287-3a02b4ce6f62"
+ },
+ "outputs": [],
+ "source": [
+ "# Loading the model and the tokinizer for the model\n",
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ " model_name = config.get(\"model_config\").get(\"base_model\"),\n",
+ " max_seq_length = config.get(\"model_config\").get(\"max_seq_length\"),\n",
+ " dtype = config.get(\"model_config\").get(\"dtype\"),\n",
+ " load_in_4bit = config.get(\"model_config\").get(\"load_in_4bit\"),\n",
+ "\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "kvValJd0oXpz",
+ "outputId": "b2aa892b-92f3-4c4c-9688-781069b48585"
+ },
+ "outputs": [],
+ "source": [
+ "# Set up GPU acceleration\n",
+ "if torch.cuda.device_count() > 1:\n",
+ " print(\"Multiple GPUs enabled\")\n",
+ " devices = [f\"cuda:{i}\" for i in range(torch.cuda.device_count())]\n",
+ " model_parallel = torch.nn.DataParallel(model, device_ids=[0, 1])\n",
+ " # Access the original model from the DataParallel object\n",
+ " model = model_parallel.module\n",
+ "else:\n",
+ " print(\"No DataParallel \")\n",
+ " #device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "DrQUZ0jtoXpz"
+ },
+ "outputs": [],
+ "source": [
+ "#model = model.half() # the model to half precision (float16)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ZPLenE03oXpz"
+ },
+ "outputs": [],
+ "source": [
+ "# Setup for QLoRA/LoRA peft of the base model\n",
+ "model = FastLanguageModel.get_peft_model(\n",
+ " model,\n",
+ " r = config.get(\"lora_config\").get(\"r\"),\n",
+ " target_modules = config.get(\"lora_config\").get(\"target_modules\"),\n",
+ " lora_alpha = config.get(\"lora_config\").get(\"lora_alpha\"),\n",
+ " lora_dropout = config.get(\"lora_config\").get(\"lora_dropout\"),\n",
+ " bias = config.get(\"lora_config\").get(\"bias\"),\n",
+ " use_gradient_checkpointing = config.get(\"lora_config\").get(\"use_gradient_checkpointing\"),\n",
+ " random_state = 42,\n",
+ " use_rslora = config.get(\"lora_config\").get(\"use_rslora\"),\n",
+ " use_dora = config.get(\"lora_config\").get(\"use_dora\"),\n",
+ " loftq_config = config.get(\"lora_config\").get(\"loftq_config\"),\n",
+ ")\n"
+ ]
+ },
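+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional check (a sketch): PEFT-wrapped models expose print_trainable_parameters(), which shows how\n",
+ "# few weights the LoRA adapters actually train. Assumes the object returned above supports this method.\n",
+ "model.print_trainable_parameters()"
+ ]
+ },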
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "dCg9_NAfoXpz",
+ "outputId": "75303ebe-3299-496e-faea-afa89f4e4c01"
+ },
+ "outputs": [],
+ "source": [
+ "from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer\n",
+ "tokenizer = AutoTokenizer.from_pretrained(config.get(\"model_config\").get(\"base_model\"))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "V4soybR7oXpz"
+ },
+ "outputs": [],
+ "source": [
+ "tokenizer.add_eos_token = True\n",
+ "tokenizer.pad_token_id = 0\n",
+ "tokenizer.padding_side = \"left\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "_sm6yQFPWNXY"
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "M0qTbrawoXpz"
+ },
+ "outputs": [],
+ "source": [
+ "config_dataset={ \"training_dataset\": {\n",
+ " \"name\": \"ruslanmv/ai-medical-dataset\", # The dataset name(huggingface/datasets)\n",
+ " \"split\": \"train\", # The dataset split\n",
+ " \"input_fields\": [\"question\", \"context\"] ,# The input fields\n",
+ " \"input_field\": \"text\",# The input field\n",
+ " },\n",
+ " }"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "DMscf5cdoXpz",
+ "outputId": "c7799645-c6bf-4b31-8070-1aa0ef60df33"
+ },
+ "outputs": [],
+ "source": [
+ "config_dataset.get(\"training_dataset\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 113,
+ "referenced_widgets": [
+ "3a97281be4c1433aa3abe6c25b7113e2",
+ "4e19e78059b842a5832ccae2f765a30c",
+ "1a72b512e1374e67a858edf2844fc157",
+ "c9cfd66b68a1437d946c83163fa877df",
+ "cccd970273ae43d2a6e60ac421bdc882",
+ "32cff795f8bc490dbf63ed130e1f581f",
+ "4a0426a353ca41cba39d4dfeba925451",
+ "284192f01a924f87afd8b5087ca9af6c",
+ "273bf76f74bc4fb492ccb67d9e202f7b",
+ "45b3259e3cac4de8bd19d12f07de2adb",
+ "b7e7896aeac74b6eae27de0677100e57",
+ "11dc1dcf6b29471580c32c818fa41d88",
+ "9344b22940c64654a82bb2ce06530e30",
+ "4f68a26f64e844c7be21cc180eb6c1a2",
+ "769b40273bab41af8eb66e494b613241",
+ "320c09781518483e82defa86c28316d1",
+ "793f49f397b54daab63194cee8d04256",
+ "fa79cfa23f3a430dab69a59d93383cd0",
+ "341dca5ac74348dd9b5a347e38fa0b40",
+ "8ba6fd1bf16a4680b8a8c9c55ecf23e7",
+ "dc85f5e365f4488fa185d0ae35fde806",
+ "51a6d3c97480476e8c22d9ad670bdc47",
+ "b8b277831f1a45109b3a4a3565fbdb9d",
+ "9f91f7ce62e243f59d72e5ba36f97b8f",
+ "1634ba52355b4681a913039666926f85",
+ "217ca5cd404d4756a399fba3aa4fbc15",
+ "bc6d92cb8837428bb7038d75e6af604e",
+ "af0233735d744b7e838f50f52c9d6cbe",
+ "8a8d3a006ee24c4393d7c2f2d040ce52",
+ "eff94d2d010e4b4f93a6dfcb61103a52",
+ "da5cd094aaae45f4a0ca051ad5babd78",
+ "8f88a5b04723482ea430679e504c65f9",
+ "8d153f070a8d4ad1b32996a9fd82beda"
+ ]
+ },
+ "id": "g2h5E--2oXp0",
+ "outputId": "8b8a3e49-a0bd-4aea-bf8d-5d2414f66547"
+ },
+ "outputs": [],
+ "source": [
+ "# Loading the training dataset\n",
+ "train_dataset = load_dataset(config_dataset.get(\"training_dataset\").get(\"name\"), split = config_dataset.get(\"training_dataset\").get(\"split\"))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "nxGRQ9sCoXp0",
+ "outputId": "1e2ce893-2a39-4521-9062-490a9e9de016"
+ },
+ "outputs": [],
+ "source": [
+ "train_dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "sGtF6NvpoXp0"
+ },
+ "outputs": [],
+ "source": [
+ "# Select the first 10 rows of the dataset\n",
+ "test_dataset = train_dataset.select(range(100))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "LipkaBaBoXp0",
+ "outputId": "4484e420-1693-4524-bf2a-19db669c5543"
+ },
+ "outputs": [],
+ "source": [
+ "test_dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "bBIK97mjoXp0",
+ "outputId": "930efcd4-5b32-4a8d-ff68-859b63293e7e"
+ },
+ "outputs": [],
+ "source": [
+ "test_dataset[1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "HERoJEG2oXp0"
+ },
+ "outputs": [],
+ "source": [
+ "medical_prompt = \"\"\"You are an AI Medical Assistant Chatbot, trained to answer medical questions. Below is an instruction that describes a task, paired with an response context. Write a response that appropriately completes the request.\n",
+ "\n",
+ "### Instruction:\n",
+ "{}\n",
+ "\n",
+ "\n",
+ "### Response:\n",
+ "{}\"\"\"\n",
+ "\n",
+ "EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN\n",
+ "def formatting_prompts_func(examples):\n",
+ " instructions = examples[\"question\"]\n",
+ " outputs = examples[\"context\"]\n",
+ " texts = []\n",
+ " for instruction, output in zip(instructions, outputs):\n",
+ " # Must add EOS_TOKEN, otherwise your generation will go on forever!\n",
+ " text = medical_prompt.format(instruction, output) + EOS_TOKEN\n",
+ " texts.append(text)\n",
+ " return { \"text\" : texts, }\n",
+ "pass"
+ ]
+ },
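+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional preview (a sketch): run the formatter on a one-row batched slice of the dataset to inspect\n",
+ "# the final prompt text before mapping the whole split. Slicing a Hugging Face Dataset returns a dict of lists.\n",
+ "preview = formatting_prompts_func(test_dataset[:1])\n",
+ "print(preview[\"text\"][0][:500])"
+ ]
+ },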
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Redv7cdFoXp0"
+ },
+ "outputs": [],
+ "source": [
+ "test_dataset= test_dataset.map(formatting_prompts_func, batched = True,)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "CSGjoG8voXp0",
+ "outputId": "92689d46-6795-4591-bc5e-211c8cc9797a"
+ },
+ "outputs": [],
+ "source": [
+ "test_dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 142
+ },
+ "id": "z5Q2wwfjoXp0",
+ "outputId": "a76d49b3-48bf-43d6-fc80-a5aa88a12634"
+ },
+ "outputs": [],
+ "source": [
+ "test_dataset['text'][1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "urBn0WMSoXp0"
+ },
+ "outputs": [],
+ "source": [
+ "is_test=True\n",
+ "if is_test:\n",
+ " train_dataset=test_dataset\n",
+ "else:\n",
+ " train_dataset= train_dataset.map(formatting_prompts_func, batched = True,)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 142
+ },
+ "id": "UevHEqo7oXp0",
+ "outputId": "dfe3869c-fd9e-4734-9462-e9eb391792e1"
+ },
+ "outputs": [],
+ "source": [
+ "train_dataset['text'][1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 140,
+ "referenced_widgets": [
+ "ffa74977e7464cebb16d3cf8ee976d51",
+ "e257e4a2bfdb48038102173d397ab2e4",
+ "67b9a3505ae644dbb3c4fc14781a2731",
+ "c4d39c87c16c4961b942d896742ff7ce",
+ "e5880b946aae4b84a94226a5d6acaf45",
+ "82c6c2752a0746f3935e069c0f8811d6",
+ "1850ab17bafd4a43b5ab5899d1875a40",
+ "53ee8f5e8b7d4076bdb0167baf2e5729",
+ "d70fd9035f9b4d82892fae34c28c46d5",
+ "af0096de28414303ba5324f4087cd92e",
+ "0f55ae30c2704632941cca4727c1c4f2"
+ ]
+ },
+ "id": "X4wxJAgnM2W0",
+ "outputId": "38c58ce9-6f4c-49c9-e21d-49bc34f5cc2e"
+ },
+ "outputs": [],
+ "source": [
+ "# Setting up the trainer for the model\n",
+ "trainer = SFTTrainer(\n",
+ " model = model,\n",
+ " tokenizer = tokenizer,\n",
+ " train_dataset = train_dataset,\n",
+ " dataset_text_field = config_dataset.get(\"training_dataset\").get(\"input_field\"),\n",
+ " max_seq_length = config.get(\"model_config\").get(\"max_seq_length\"),\n",
+ " dataset_num_proc = 2,\n",
+ " packing = False,\n",
+ " args = TrainingArguments(\n",
+ " per_device_train_batch_size = config.get(\"training_config\").get(\"per_device_train_batch_size\"),\n",
+ " gradient_accumulation_steps = config.get(\"training_config\").get(\"gradient_accumulation_steps\"),\n",
+ " warmup_steps = config.get(\"training_config\").get(\"warmup_steps\"),\n",
+ " max_steps = config.get(\"training_config\").get(\"max_steps\"),\n",
+ " num_train_epochs= config.get(\"training_config\").get(\"num_train_epochs\"),\n",
+ " learning_rate = config.get(\"training_config\").get(\"learning_rate\"),\n",
+ " fp16 = config.get(\"training_config\").get(\"fp16\"),\n",
+ " bf16 = config.get(\"training_config\").get(\"bf16\"),\n",
+ " logging_steps = config.get(\"training_config\").get(\"logging_steps\"),\n",
+ " optim = config.get(\"training_config\").get(\"optim\"),\n",
+ " weight_decay = config.get(\"training_config\").get(\"weight_decay\"),\n",
+ " lr_scheduler_type = config.get(\"training_config\").get(\"lr_scheduler_type\"),\n",
+ " seed = 42,\n",
+ " output_dir = config.get(\"training_config\").get(\"output_dir\"),\n",
+ " ),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "RJl7dk8yoXp1",
+ "outputId": "b37c70bb-50a6-4319-b99c-c0d9c05f035b"
+ },
+ "outputs": [],
+ "source": [
+ "# Memory statistics before training\n",
+ "gpu_statistics = torch.cuda.get_device_properties(0)\n",
+ "reserved_memory = round(torch.cuda.max_memory_reserved() / 1024**3, 2)\n",
+ "max_memory = round(gpu_statistics.total_memory / 1024**3, 2)\n",
+ "print(f\"Reserved Memory: {reserved_memory}GB\")\n",
+ "print(f\"Max Memory: {max_memory}GB\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Q-g-4RvNXkyD"
+ },
+ "outputs": [],
+ "source": [
+ "## [ 1038/2651250 53:49 < 2295:10:28, 0.32 it/s, Epoch 0.00/1] old"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 948
+ },
+ "id": "yI9mEQ7ZOUx2",
+ "outputId": "6466d591-76f8-45e2-e665-39ad9bf8ae7f",
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "# Training the model\n",
+ "trainer_stats = trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "YQFEr64koXp1",
+ "outputId": "2e1b3775-1f5d-4b8e-a0ef-32266cb7fa2a"
+ },
+ "outputs": [],
+ "source": [
+ "# Memory statistics after training\n",
+ "used_memory = round(torch.cuda.max_memory_allocated() / 1024**3, 2)\n",
+ "used_memory_lora = round(used_memory - reserved_memory, 2)\n",
+ "used_memory_persentage = round((used_memory / max_memory) * 100, 2)\n",
+ "used_memory_lora_persentage = round((used_memory_lora / max_memory) * 100, 2)\n",
+ "print(f\"Used Memory: {used_memory}GB ({used_memory_persentage}%)\")\n",
+ "print(f\"Used Memory for training(fine-tuning) LoRA: {used_memory_lora}GB ({used_memory_lora_persentage}%)\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "1YJB4bZyoXp1"
+ },
+ "outputs": [],
+ "source": [
+ "# Saving the trainer stats\n",
+ "with open(\"trainer_stats.json\", \"w\") as f:\n",
+ " json.dump(trainer_stats, f, indent=4)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-1HtsRpVnHTj"
+ },
+ "outputs": [],
+ "source": [
+ "# Locally saving the model and pushing it to the Hugging Face Hub (only LoRA adapters)\n",
+ "model.save_pretrained(config.get(\"model_config\").get(\"finetuned_model\"))\n",
+ "model.push_to_hub(config.get(\"model_config\").get(\"finetuned_model\"), tokenizer = tokenizer)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yRO4pPP0oXp1"
+ },
+ "outputs": [],
+ "source": [
+ "# Saving the model using merged_16bit(float16), merged_4bit(int4) or quantization options(q8_0, q4_k_m, q5_k_m)...\n",
+ "model.save_pretrained_merged(config.get(\"model_config\").get(\"finetuned_model\"), tokenizer, save_method = \"merged_16bit\",)\n",
+ "model.push_to_hub_merged(config.get(\"model_config\").get(\"finetuned_model\"), tokenizer, save_method = \"merged_16bit\")\n",
+ "\n",
+ "model.save_pretrained_merged(config.get(\"model_config\").get(\"finetuned_model\"), tokenizer, save_method = \"merged_4bit\",)\n",
+ "model.push_to_hub_merged(config.get(\"model_config\").get(\"finetuned_model\"), tokenizer, save_method = \"merged_4bit\")\n",
+ "\n",
+ "model.save_pretrained_gguf(config.get(\"model_config\").get(\"finetuned_model\"), tokenizer)\n",
+ "model.push_to_hub_gguf(config.get(\"model_config\").get(\"finetuned_model\"), tokenizer)\n",
+ "\n",
+ "model.save_pretrained_gguf(config.get(\"model_config\").get(\"finetuned_model\"), tokenizer, quantization_method = \"f16\")\n",
+ "model.push_to_hub_gguf(config.get(\"model_config\").get(\"finetuned_model\"), tokenizer, quantization_method = \"f16\")\n",
+ "\n",
+ "model.save_pretrained_gguf(config.get(\"model_config\").get(\"finetuned_model\"), tokenizer, quantization_method = \"q4_k_m\")\n",
+ "model.push_to_hub_gguf(config.get(\"model_config\").get(\"finetuned_model\"), tokenizer, quantization_method = \"q4_k_m\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ozVcalyP_JLs"
+ },
+ "outputs": [],
+ "source": [
+ "# Loading the fine-tuned model and the tokenizer for inference\n",
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ " model_name = config.get(\"model_config\").get(\"finetuned_model\"),\n",
+ " max_seq_length = config.get(\"model_config\").get(\"max_seq_length\"),\n",
+ " dtype = config.get(\"model_config\").get(\"dtype\"),\n",
+ " load_in_4bit = config.get(\"model_config\").get(\"load_in_4bit\"),\n",
+ " )\n",
+ "\n",
+ "# Using FastLanguageModel for fast inference\n",
+ "FastLanguageModel.for_inference(model)\n",
+ "\n",
+ "# Tokenizing the input and generating the output\n",
+ "inputs = tokenizer(\n",
+ "[\n",
+ " \"<|start_header_id|>system<|end_header_id|> You are a Medical AI chatbot assistant .<|eot_id|><|start_header_id|>user<|end_header_id|> This is the question: What was the main cause of the inflammatory CD4+ T cells?<|eot_id|>\"\n",
+ "], return_tensors = \"pt\").to(\"cuda\")\n",
+ "outputs = model.generate(**inputs, max_new_tokens = 256, use_cache = True)\n",
+ "tokenizer.batch_decode(outputs, skip_special_tokens = True)"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "L4",
+ "machine_shape": "hm",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "0005f2d9fe1e4cc98ea58b0c2868b433": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_45c1d5b0df0e420a87f791dd4cf0e425",
+ "max": 100,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_9ed49f1a099846a3a65cd6608bafb0e4",
+ "value": 100
+ }
+ },
+ "0058ed544fed4272848a891a68b9adc0": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "00eea4b0c6e44c62900ea8e7d919efe9": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "02fc530028ea4d538b7f6b48463ae700": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "058b2b9959b84b6f9f5d3862ef53d029": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_7807f312425b4f4d9249aa1ac77d7461",
+ "placeholder": "",
+ "style": "IPY_MODEL_d8e7ea9552a84b8284b31d77090b54af",
+ "value": "Map (num_proc=2): 100%"
+ }
+ },
+ "0f55ae30c2704632941cca4727c1c4f2": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "11dc1dcf6b29471580c32c818fa41d88": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_9344b22940c64654a82bb2ce06530e30",
+ "IPY_MODEL_4f68a26f64e844c7be21cc180eb6c1a2",
+ "IPY_MODEL_769b40273bab41af8eb66e494b613241"
+ ],
+ "layout": "IPY_MODEL_320c09781518483e82defa86c28316d1"
+ }
+ },
+ "1634ba52355b4681a913039666926f85": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_eff94d2d010e4b4f93a6dfcb61103a52",
+ "max": 18,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_da5cd094aaae45f4a0ca051ad5babd78",
+ "value": 18
+ }
+ },
+ "1850ab17bafd4a43b5ab5899d1875a40": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "1a72b512e1374e67a858edf2844fc157": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_284192f01a924f87afd8b5087ca9af6c",
+ "max": 18,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_273bf76f74bc4fb492ccb67d9e202f7b",
+ "value": 18
+ }
+ },
+ "217ca5cd404d4756a399fba3aa4fbc15": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_8f88a5b04723482ea430679e504c65f9",
+ "placeholder": "",
+ "style": "IPY_MODEL_8d153f070a8d4ad1b32996a9fd82beda",
+ "value": " 18/18 [00:00<00:00, 9.43it/s]"
+ }
+ },
+ "22ea45365d21439fb5069974bbe69711": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "23a71f8847e647daba35e495706fc846": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_22ea45365d21439fb5069974bbe69711",
+ "placeholder": "",
+ "style": "IPY_MODEL_bd087d0aa3214c5dbecc9b0bd4d976df",
+ "value": "Resolving data files: 100%"
+ }
+ },
+ "273bf76f74bc4fb492ccb67d9e202f7b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "284192f01a924f87afd8b5087ca9af6c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "2c5564fb033346afbe7692a24a52b302": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "31a203cdd2f54cda8a05214844888156": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "320c09781518483e82defa86c28316d1": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "32cff795f8bc490dbf63ed130e1f581f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "33fb10908c23457aa4796626102fc8c5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "341dca5ac74348dd9b5a347e38fa0b40": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "3564e3cf0fe84281838d84525794e735": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_912164947c5847908424f3e60c5adb64",
+ "IPY_MODEL_7517ce80636040e29665a9353afab183",
+ "IPY_MODEL_e14b9d980a1a41fb9e81385cb0f73d3a"
+ ],
+ "layout": "IPY_MODEL_ada78aafba3f47ab8eb45cf3c83a6805"
+ }
+ },
+ "37803098ceed4528bb690ebee028c840": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "39d3b72ab6214bcf9b0bb6b6294e957c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "3a97281be4c1433aa3abe6c25b7113e2": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_4e19e78059b842a5832ccae2f765a30c",
+ "IPY_MODEL_1a72b512e1374e67a858edf2844fc157",
+ "IPY_MODEL_c9cfd66b68a1437d946c83163fa877df"
+ ],
+ "layout": "IPY_MODEL_cccd970273ae43d2a6e60ac421bdc882"
+ }
+ },
+ "3f7afd4bd28842cbb73e62c155667030": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_9a5fd3a68fd1445f92bea51a7fec3e6b",
+ "max": 18,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_37803098ceed4528bb690ebee028c840",
+ "value": 18
+ }
+ },
+ "44f189b81bbd48ca8cb146ead641d2b5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_e903140c8c794c48b231924d3975b7a6",
+ "placeholder": "",
+ "style": "IPY_MODEL_7e74d789c82747e0b5066a00b9e36c1d",
+ "value": " 100/100 [00:00<00:00, 125.88 examples/s]"
+ }
+ },
+ "45b3259e3cac4de8bd19d12f07de2adb": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "45c1d5b0df0e420a87f791dd4cf0e425": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "4a0426a353ca41cba39d4dfeba925451": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "4e19e78059b842a5832ccae2f765a30c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_32cff795f8bc490dbf63ed130e1f581f",
+ "placeholder": "",
+ "style": "IPY_MODEL_4a0426a353ca41cba39d4dfeba925451",
+ "value": "Resolving data files: 100%"
+ }
+ },
+ "4f68a26f64e844c7be21cc180eb6c1a2": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_341dca5ac74348dd9b5a347e38fa0b40",
+ "max": 18,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_8ba6fd1bf16a4680b8a8c9c55ecf23e7",
+ "value": 18
+ }
+ },
+ "51a6d3c97480476e8c22d9ad670bdc47": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "53ee8f5e8b7d4076bdb0167baf2e5729": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "58b932a03b2c4aa4891d541f186244b9": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "5d1fbd3c62d94df7befdefc451221414": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_8ad6abb48f38469f9d399eea8f5e5b70",
+ "IPY_MODEL_6cea0da24cf54811a43168c606759bab",
+ "IPY_MODEL_eb8c88f5c06c49fe9099371b3cf112ae"
+ ],
+ "layout": "IPY_MODEL_89a1354722e640758978befc06ed4a78"
+ }
+ },
+ "64539b4212fe4d989976f56369bb746b": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "67b9a3505ae644dbb3c4fc14781a2731": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_53ee8f5e8b7d4076bdb0167baf2e5729",
+ "max": 100,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_d70fd9035f9b4d82892fae34c28c46d5",
+ "value": 100
+ }
+ },
+ "696e82ec6a174974a90d5abc7c101ee7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "6cea0da24cf54811a43168c606759bab": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_dade882aca304a31b693a2c58807d825",
+ "max": 18,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_02fc530028ea4d538b7f6b48463ae700",
+ "value": 18
+ }
+ },
+ "72eca1e2871b458abd3383d9711215a2": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_058b2b9959b84b6f9f5d3862ef53d029",
+ "IPY_MODEL_85d4879bd7d64766905db34cef052fed",
+ "IPY_MODEL_44f189b81bbd48ca8cb146ead641d2b5"
+ ],
+ "layout": "IPY_MODEL_f89c5c949e984361bce7f97d86d2a2e5"
+ }
+ },
+ "734b6d3e3406403293c4bc955a643528": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_dc3b2edc3f5d480a93b57b15b4444608",
+ "placeholder": "",
+ "style": "IPY_MODEL_7967d420aff1414e9fe53eb04c928eb4",
+ "value": "Map: 100%"
+ }
+ },
+ "7517ce80636040e29665a9353afab183": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_bb078c8c1f6a48359dc654d91ece684d",
+ "max": 18,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_9b9322336b564a409086955ebda07fc3",
+ "value": 18
+ }
+ },
+ "769b40273bab41af8eb66e494b613241": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_dc85f5e365f4488fa185d0ae35fde806",
+ "placeholder": "",
+ "style": "IPY_MODEL_51a6d3c97480476e8c22d9ad670bdc47",
+ "value": " 18/18 [00:00<00:00, 1567.70it/s]"
+ }
+ },
+ "7807f312425b4f4d9249aa1ac77d7461": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "793f49f397b54daab63194cee8d04256": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "7967d420aff1414e9fe53eb04c928eb4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "7e11cccce8be49008f8db3a0c3ea603d": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "7e74d789c82747e0b5066a00b9e36c1d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "82c6c2752a0746f3935e069c0f8811d6": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "85d4879bd7d64766905db34cef052fed": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_0058ed544fed4272848a891a68b9adc0",
+ "max": 100,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_33fb10908c23457aa4796626102fc8c5",
+ "value": 100
+ }
+ },
+ "89a1354722e640758978befc06ed4a78": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "8a195771bdc0462e8f9fbb60eb9141b1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "8a8d3a006ee24c4393d7c2f2d040ce52": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "8ad6abb48f38469f9d399eea8f5e5b70": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_39d3b72ab6214bcf9b0bb6b6294e957c",
+ "placeholder": "",
+ "style": "IPY_MODEL_696e82ec6a174974a90d5abc7c101ee7",
+ "value": "Resolving data files: 100%"
+ }
+ },
+ "8ba6fd1bf16a4680b8a8c9c55ecf23e7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "8d153f070a8d4ad1b32996a9fd82beda": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "8f88a5b04723482ea430679e504c65f9": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "912164947c5847908424f3e60c5adb64": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_ff108c92fb5547869ee545cf9a094b07",
+ "placeholder": "",
+ "style": "IPY_MODEL_2c5564fb033346afbe7692a24a52b302",
+ "value": "Loading dataset shards: 100%"
+ }
+ },
+ "9344b22940c64654a82bb2ce06530e30": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_793f49f397b54daab63194cee8d04256",
+ "placeholder": "",
+ "style": "IPY_MODEL_fa79cfa23f3a430dab69a59d93383cd0",
+ "value": "Resolving data files: 100%"
+ }
+ },
+ "963c0aa5620b4ea8b5a903894646121c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "9a5fd3a68fd1445f92bea51a7fec3e6b": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "9b9322336b564a409086955ebda07fc3": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "9bceb9eddb2147c1abbf3391c70e6784": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "9ed49f1a099846a3a65cd6608bafb0e4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "9f91f7ce62e243f59d72e5ba36f97b8f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_af0233735d744b7e838f50f52c9d6cbe",
+ "placeholder": "",
+ "style": "IPY_MODEL_8a8d3a006ee24c4393d7c2f2d040ce52",
+ "value": "Loading dataset shards: 100%"
+ }
+ },
+ "a419499622cd4374937423a79677298f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_b93514308ae44afbb1a0511f5f9c6ddf",
+ "placeholder": "",
+ "style": "IPY_MODEL_58b932a03b2c4aa4891d541f186244b9",
+ "value": " 18/18 [00:00<00:00, 1458.49it/s]"
+ }
+ },
+ "ada78aafba3f47ab8eb45cf3c83a6805": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "af0096de28414303ba5324f4087cd92e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "af0233735d744b7e838f50f52c9d6cbe": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "b7e7896aeac74b6eae27de0677100e57": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "b8b277831f1a45109b3a4a3565fbdb9d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_9f91f7ce62e243f59d72e5ba36f97b8f",
+ "IPY_MODEL_1634ba52355b4681a913039666926f85",
+ "IPY_MODEL_217ca5cd404d4756a399fba3aa4fbc15"
+ ],
+ "layout": "IPY_MODEL_bc6d92cb8837428bb7038d75e6af604e"
+ }
+ },
+ "b93514308ae44afbb1a0511f5f9c6ddf": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "bb078c8c1f6a48359dc654d91ece684d": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "bb1156b7d349440d9cc8a2f0328465a7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_23a71f8847e647daba35e495706fc846",
+ "IPY_MODEL_3f7afd4bd28842cbb73e62c155667030",
+ "IPY_MODEL_a419499622cd4374937423a79677298f"
+ ],
+ "layout": "IPY_MODEL_64539b4212fe4d989976f56369bb746b"
+ }
+ },
+ "bc6d92cb8837428bb7038d75e6af604e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "bd087d0aa3214c5dbecc9b0bd4d976df": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "be6162f66e594d3ebd8c53ebab3bbfa6": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_963c0aa5620b4ea8b5a903894646121c",
+ "placeholder": "",
+ "style": "IPY_MODEL_31a203cdd2f54cda8a05214844888156",
+ "value": " 100/100 [00:00<00:00, 5440.44 examples/s]"
+ }
+ },
+ "c4d39c87c16c4961b942d896742ff7ce": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_af0096de28414303ba5324f4087cd92e",
+ "placeholder": "",
+ "style": "IPY_MODEL_0f55ae30c2704632941cca4727c1c4f2",
+ "value": " 100/100 [00:01<00:00, 113.55 examples/s]"
+ }
+ },
+ "c9cfd66b68a1437d946c83163fa877df": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_45b3259e3cac4de8bd19d12f07de2adb",
+ "placeholder": "",
+ "style": "IPY_MODEL_b7e7896aeac74b6eae27de0677100e57",
+ "value": " 18/18 [00:00<00:00, 1.32it/s]"
+ }
+ },
+ "cccd970273ae43d2a6e60ac421bdc882": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "d70fd9035f9b4d82892fae34c28c46d5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "d8e7ea9552a84b8284b31d77090b54af": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "da5cd094aaae45f4a0ca051ad5babd78": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "dade882aca304a31b693a2c58807d825": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "dc3b2edc3f5d480a93b57b15b4444608": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "dc85f5e365f4488fa185d0ae35fde806": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e14b9d980a1a41fb9e81385cb0f73d3a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_9bceb9eddb2147c1abbf3391c70e6784",
+ "placeholder": "",
+ "style": "IPY_MODEL_8a195771bdc0462e8f9fbb60eb9141b1",
+ "value": " 18/18 [00:35<00:00, 1.20it/s]"
+ }
+ },
+ "e257e4a2bfdb48038102173d397ab2e4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_82c6c2752a0746f3935e069c0f8811d6",
+ "placeholder": "",
+ "style": "IPY_MODEL_1850ab17bafd4a43b5ab5899d1875a40",
+ "value": "Map (num_proc=2): 100%"
+ }
+ },
+ "e3bd7f85ce194cd4b697c2eb82038658": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_734b6d3e3406403293c4bc955a643528",
+ "IPY_MODEL_0005f2d9fe1e4cc98ea58b0c2868b433",
+ "IPY_MODEL_be6162f66e594d3ebd8c53ebab3bbfa6"
+ ],
+ "layout": "IPY_MODEL_7e11cccce8be49008f8db3a0c3ea603d"
+ }
+ },
+ "e5880b946aae4b84a94226a5d6acaf45": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e903140c8c794c48b231924d3975b7a6": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "eb8c88f5c06c49fe9099371b3cf112ae": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_00eea4b0c6e44c62900ea8e7d919efe9",
+ "placeholder": "",
+ "style": "IPY_MODEL_fe17bedb5ef04d8b9e064fa1e0d75185",
+ "value": " 18/18 [00:00<00:00, 1.42it/s]"
+ }
+ },
+ "eff94d2d010e4b4f93a6dfcb61103a52": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "f89c5c949e984361bce7f97d86d2a2e5": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "fa79cfa23f3a430dab69a59d93383cd0": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "fe17bedb5ef04d8b9e064fa1e0d75185": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "ff108c92fb5547869ee545cf9a094b07": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "ffa74977e7464cebb16d3cf8ee976d51": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_e257e4a2bfdb48038102173d397ab2e4",
+ "IPY_MODEL_67b9a3505ae644dbb3c4fc14781a2731",
+ "IPY_MODEL_c4d39c87c16c4961b942d896742ff7ce"
+ ],
+ "layout": "IPY_MODEL_e5880b946aae4b84a94226a5d6acaf45"
+ }
+ }
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/ai-medical-chatbot-master/6-FineTunning/Howto-Finetuning-Llama3-with-unsloth.ipynb b/ai-medical-chatbot-master/6-FineTunning/Howto-Finetuning-Llama3-with-unsloth.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..3fc2921b7d88a28dc56b3772fcc8f5e537d3af75
--- /dev/null
+++ b/ai-medical-chatbot-master/6-FineTunning/Howto-Finetuning-Llama3-with-unsloth.ipynb
@@ -0,0 +1,6531 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# How to FineTune Llama 3 with SFTTrainer and Unsloth\n",
+ "Hello everyone, today we are going to show how we can Fine Tune Llama 3 with SFTTrainer and Unsloth\n",
+ "First we are going to perform a simmple Fine Tunning by using SFTTrainer\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 1 - Installation of Pytorch\n",
+ "The first step is install pythorch v 2.2.1 with Cuda 12.1 "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: pip in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (23.3)\n",
+ "Collecting pip\n",
+ " Downloading pip-24.0-py3-none-any.whl.metadata (3.6 kB)\n",
+ "Downloading pip-24.0-py3-none-any.whl (2.1 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.1/2.1 MB\u001b[0m \u001b[31m38.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
+ "\u001b[?25hInstalling collected packages: pip\n",
+ " Attempting uninstall: pip\n",
+ " Found existing installation: pip 23.3\n",
+ " Uninstalling pip-23.3:\n",
+ " Successfully uninstalled pip-23.3\n",
+ "Successfully installed pip-24.0\n",
+ "Looking in indexes: https://download.pytorch.org/whl/cu121\n",
+ "Collecting torch==2.2.1\n",
+ " Downloading https://download.pytorch.org/whl/cu121/torch-2.2.1%2Bcu121-cp310-cp310-linux_x86_64.whl (757.3 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m757.3/757.3 MB\u001b[0m \u001b[31m9.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: torchvision in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (0.15.2)\n",
+ "Collecting torchaudio\n",
+ " Downloading https://download.pytorch.org/whl/cu121/torchaudio-2.3.0%2Bcu121-cp310-cp310-linux_x86_64.whl (3.4 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 MB\u001b[0m \u001b[31m86.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m\n",
+ "\u001b[?25hCollecting xformers\n",
+ " Downloading https://download.pytorch.org/whl/cu121/xformers-0.0.26.post1-cp310-cp310-manylinux2014_x86_64.whl (222.7 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m222.7/222.7 MB\u001b[0m \u001b[31m33.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: filelock in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from torch==2.2.1) (3.9.0)\n",
+ "Collecting typing-extensions>=4.8.0 (from torch==2.2.1)\n",
+ " Downloading https://download.pytorch.org/whl/typing_extensions-4.9.0-py3-none-any.whl (32 kB)\n",
+ "Requirement already satisfied: sympy in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from torch==2.2.1) (1.12)\n",
+ "Requirement already satisfied: networkx in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from torch==2.2.1) (2.8.4)\n",
+ "Requirement already satisfied: jinja2 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from torch==2.2.1) (3.1.3)\n",
+ "Requirement already satisfied: fsspec in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from torch==2.2.1) (2022.11.0)\n",
+ "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.2.1)\n",
+ " Downloading https://download.pytorch.org/whl/cu121/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m23.7/23.7 MB\u001b[0m \u001b[31m97.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
+ "\u001b[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch==2.2.1)\n",
+ " Downloading https://download.pytorch.org/whl/cu121/nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.6/823.6 kB\u001b[0m \u001b[31m50.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting nvidia-cuda-cupti-cu12==12.1.105 (from torch==2.2.1)\n",
+ " Downloading https://download.pytorch.org/whl/cu121/nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.1/14.1 MB\u001b[0m \u001b[31m120.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
+ "\u001b[?25hCollecting nvidia-cudnn-cu12==8.9.2.26 (from torch==2.2.1)\n",
+ " Downloading https://download.pytorch.org/whl/cu121/nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m731.7/731.7 MB\u001b[0m \u001b[31m12.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
+ "\u001b[?25hCollecting nvidia-cublas-cu12==12.1.3.1 (from torch==2.2.1)\n",
+ " Downloading https://download.pytorch.org/whl/cu121/nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m410.6/410.6 MB\u001b[0m \u001b[31m20.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
+ "\u001b[?25hCollecting nvidia-cufft-cu12==11.0.2.54 (from torch==2.2.1)\n",
+ " Downloading https://download.pytorch.org/whl/cu121/nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m121.6/121.6 MB\u001b[0m \u001b[31m50.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
+ "\u001b[?25hCollecting nvidia-curand-cu12==10.3.2.106 (from torch==2.2.1)\n",
+ " Downloading https://download.pytorch.org/whl/cu121/nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.5/56.5 MB\u001b[0m \u001b[31m76.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
+ "\u001b[?25hCollecting nvidia-cusolver-cu12==11.4.5.107 (from torch==2.2.1)\n",
+ " Downloading https://download.pytorch.org/whl/cu121/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m124.2/124.2 MB\u001b[0m \u001b[31m52.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
+ "\u001b[?25hCollecting nvidia-cusparse-cu12==12.1.0.106 (from torch==2.2.1)\n",
+ " Downloading https://download.pytorch.org/whl/cu121/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m196.0/196.0 MB\u001b[0m \u001b[31m36.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
+ "\u001b[?25hCollecting nvidia-nccl-cu12==2.19.3 (from torch==2.2.1)\n",
+ " Downloading https://download.pytorch.org/whl/cu121/nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl (166.0 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m166.0/166.0 MB\u001b[0m \u001b[31m40.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
+ "\u001b[?25hCollecting nvidia-nvtx-cu12==12.1.105 (from torch==2.2.1)\n",
+ " Downloading https://download.pytorch.org/whl/cu121/nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m99.1/99.1 kB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting triton==2.2.0 (from torch==2.2.1)\n",
+ " Downloading https://download.pytorch.org/whl/triton-2.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (167.9 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m167.9/167.9 MB\u001b[0m \u001b[31m19.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
+ "\u001b[?25hCollecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch==2.2.1)\n",
+ " Downloading https://download.pytorch.org/whl/cu121/nvidia_nvjitlink_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (19.8 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m19.8/19.8 MB\u001b[0m \u001b[31m125.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: numpy in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from torchvision) (1.23.5)\n",
+ "Requirement already satisfied: requests in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from torchvision) (2.31.0)\n",
+ "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from torchvision) (10.3.0)\n",
+ "INFO: pip is looking at multiple versions of torchaudio to determine which version is compatible with other requirements. This could take a while.\n",
+ "Collecting torchaudio\n",
+ " Downloading https://download.pytorch.org/whl/cu121/torchaudio-2.2.2%2Bcu121-cp310-cp310-linux_x86_64.whl (3.4 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 MB\u001b[0m \u001b[31m129.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h Downloading https://download.pytorch.org/whl/cu121/torchaudio-2.2.1%2Bcu121-cp310-cp310-linux_x86_64.whl (3.4 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 MB\u001b[0m \u001b[31m87.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m\n",
+ "\u001b[?25hINFO: pip is looking at multiple versions of xformers to determine which version is compatible with other requirements. This could take a while.\n",
+ "Collecting xformers\n",
+ " Downloading https://download.pytorch.org/whl/cu121/xformers-0.0.26-cp310-cp310-manylinux2014_x86_64.whl (222.6 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m222.6/222.6 MB\u001b[0m \u001b[31m11.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
+ "\u001b[?25h Downloading https://download.pytorch.org/whl/cu121/xformers-0.0.25.post1-cp310-cp310-manylinux2014_x86_64.whl (222.5 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m222.5/222.5 MB\u001b[0m \u001b[31m34.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
+ "\u001b[?25h Downloading https://download.pytorch.org/whl/cu121/xformers-0.0.25-cp310-cp310-manylinux2014_x86_64.whl (222.5 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m222.5/222.5 MB\u001b[0m \u001b[31m20.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: MarkupSafe>=2.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from jinja2->torch==2.2.1) (2.1.1)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from requests->torchvision) (2.0.4)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from requests->torchvision) (3.7)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from requests->torchvision) (1.26.18)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from requests->torchvision) (2024.2.2)\n",
+ "Requirement already satisfied: mpmath>=0.19 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from sympy->torch==2.2.1) (1.3.0)\n",
+ "Installing collected packages: typing-extensions, triton, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch, xformers, torchaudio\n",
+ " Attempting uninstall: typing-extensions\n",
+ " Found existing installation: typing_extensions 4.4.0\n",
+ " Uninstalling typing_extensions-4.4.0:\n",
+ " Successfully uninstalled typing_extensions-4.4.0\n",
+ " Attempting uninstall: torch\n",
+ " Found existing installation: torch 2.0.1\n",
+ " Uninstalling torch-2.0.1:\n",
+ " Successfully uninstalled torch-2.0.1\n",
+ "Successfully installed nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.19.3 nvidia-nvjitlink-cu12-12.1.105 nvidia-nvtx-cu12-12.1.105 torch-2.2.1+cu121 torchaudio-2.2.1+cu121 triton-2.2.0 typing-extensions-4.9.0 xformers-0.0.25\n"
+ ]
+ }
+ ],
+ "source": [
+ "!python -m pip install --upgrade pip\n",
+ "!pip3 install torch==2.2.1 torchvision torchaudio xformers --index-url https://download.pytorch.org/whl/cu121"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 3 - Installation of Uslotch packages"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Collecting unsloth@ git+https://github.com/unslothai/unsloth.git (from unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git)\n",
+ " Cloning https://github.com/unslothai/unsloth.git to /tmp/wsuser/pip-install-8a93kdi0/unsloth_56c62c14bb3f4be29d884342054fdd22\n",
+ " Running command git clone --filter=blob:none --quiet https://github.com/unslothai/unsloth.git /tmp/wsuser/pip-install-8a93kdi0/unsloth_56c62c14bb3f4be29d884342054fdd22\n",
+ " Resolved https://github.com/unslothai/unsloth.git to commit 4211cc01409e3ced4f7abebaf68e244193b46e2c\n",
+ " Installing build dependencies ... \u001b[?25ldone\n",
+ "\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n",
+ "\u001b[?25h Installing backend dependencies ... \u001b[?25ldone\n",
+ "\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n",
+ "\u001b[?25hRequirement already satisfied: tyro in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (0.8.3)\n",
+ "Requirement already satisfied: transformers>=4.38.2 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (4.40.2)\n",
+ "Requirement already satisfied: datasets>=2.16.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (2.19.1)\n",
+ "Requirement already satisfied: sentencepiece in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (0.1.97)\n",
+ "Requirement already satisfied: tqdm in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (4.65.0)\n",
+ "Requirement already satisfied: psutil in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (5.9.0)\n",
+ "Requirement already satisfied: wheel>=0.42.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (0.43.0)\n",
+ "Requirement already satisfied: numpy in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (1.23.5)\n",
+ "Requirement already satisfied: protobuf<4.0.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (3.20.3)\n",
+ "Requirement already satisfied: filelock in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (3.9.0)\n",
+ "Requirement already satisfied: pyarrow>=12.0.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (16.0.0)\n",
+ "Requirement already satisfied: pyarrow-hotfix in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (0.6)\n",
+ "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (0.3.8)\n",
+ "Requirement already satisfied: pandas in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (1.5.3)\n",
+ "Requirement already satisfied: requests>=2.19.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (2.31.0)\n",
+ "Requirement already satisfied: xxhash in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (3.4.1)\n",
+ "Requirement already satisfied: multiprocess in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (0.70.16)\n",
+ "Requirement already satisfied: fsspec<=2024.3.1,>=2023.1.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (2024.3.1)\n",
+ "Requirement already satisfied: aiohttp in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (3.9.3)\n",
+ "Requirement already satisfied: huggingface-hub>=0.21.2 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (0.23.0)\n",
+ "Requirement already satisfied: packaging in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (23.0)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (6.0)\n",
+ "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from transformers>=4.38.2->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (2022.3.15)\n",
+ "Requirement already satisfied: tokenizers<0.20,>=0.19 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from transformers>=4.38.2->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (0.19.1)\n",
+ "Requirement already satisfied: safetensors>=0.4.1 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from transformers>=4.38.2->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (0.4.3)\n",
+ "Requirement already satisfied: docstring-parser>=0.14.1 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from tyro->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (0.16)\n",
+ "Requirement already satisfied: typing-extensions>=4.7.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from tyro->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (4.9.0)\n",
+ "Requirement already satisfied: rich>=11.1.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from tyro->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (13.7.1)\n",
+ "Requirement already satisfied: shtab>=1.5.6 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from tyro->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (1.7.1)\n",
+ "Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from aiohttp->datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (1.2.0)\n",
+ "Requirement already satisfied: attrs>=17.3.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from aiohttp->datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (23.1.0)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from aiohttp->datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (1.3.3)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from aiohttp->datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (6.0.2)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from aiohttp->datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (1.8.1)\n",
+ "Requirement already satisfied: async-timeout<5.0,>=4.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from aiohttp->datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (4.0.2)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from requests>=2.19.0->datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (2.0.4)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from requests>=2.19.0->datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (3.7)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from requests>=2.19.0->datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (1.26.18)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from requests>=2.19.0->datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (2024.2.2)\n",
+ "Requirement already satisfied: markdown-it-py>=2.2.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from rich>=11.1.0->tyro->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (3.0.0)\n",
+ "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from rich>=11.1.0->tyro->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (2.15.1)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.1 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from pandas->datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (2.8.2)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from pandas->datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (2022.7)\n",
+ "Requirement already satisfied: mdurl~=0.1 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich>=11.1.0->tyro->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (0.1.2)\n",
+ "Requirement already satisfied: six>=1.5 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from python-dateutil>=2.8.1->pandas->datasets>=2.16.0->unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git) (1.16.0)\n",
+ "Requirement already satisfied: trl in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (0.8.6)\n",
+ "Requirement already satisfied: peft in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (0.10.0)\n",
+ "Requirement already satisfied: accelerate in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (0.30.0)\n",
+ "Requirement already satisfied: bitsandbytes in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (0.43.1)\n",
+ "Requirement already satisfied: datasets in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (2.19.1)\n",
+ "Requirement already satisfied: filelock in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets) (3.9.0)\n",
+ "Requirement already satisfied: numpy>=1.17 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets) (1.23.5)\n",
+ "Requirement already satisfied: pyarrow>=12.0.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets) (16.0.0)\n",
+ "Requirement already satisfied: pyarrow-hotfix in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets) (0.6)\n",
+ "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets) (0.3.8)\n",
+ "Requirement already satisfied: pandas in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets) (1.5.3)\n",
+ "Requirement already satisfied: requests>=2.19.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets) (2.31.0)\n",
+ "Requirement already satisfied: tqdm>=4.62.1 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets) (4.65.0)\n",
+ "Requirement already satisfied: xxhash in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets) (3.4.1)\n",
+ "Requirement already satisfied: multiprocess in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets) (0.70.16)\n",
+ "Requirement already satisfied: fsspec<=2024.3.1,>=2023.1.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets) (2024.3.1)\n",
+ "Requirement already satisfied: aiohttp in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets) (3.9.3)\n",
+ "Requirement already satisfied: huggingface-hub>=0.21.2 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets) (0.23.0)\n",
+ "Requirement already satisfied: packaging in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets) (23.0)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from datasets) (6.0)\n",
+ "Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from aiohttp->datasets) (1.2.0)\n",
+ "Requirement already satisfied: attrs>=17.3.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from aiohttp->datasets) (23.1.0)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from aiohttp->datasets) (1.3.3)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from aiohttp->datasets) (6.0.2)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from aiohttp->datasets) (1.8.1)\n",
+ "Requirement already satisfied: async-timeout<5.0,>=4.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from aiohttp->datasets) (4.0.2)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from huggingface-hub>=0.21.2->datasets) (4.9.0)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from requests>=2.19.0->datasets) (2.0.4)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from requests>=2.19.0->datasets) (3.7)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from requests>=2.19.0->datasets) (1.26.18)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from requests>=2.19.0->datasets) (2024.2.2)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.1 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from pandas->datasets) (2.8.2)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from pandas->datasets) (2022.7)\n",
+ "Requirement already satisfied: six>=1.5 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n",
+ "Requirement already satisfied: hyperopt in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (0.2.5)\n",
+ "Requirement already satisfied: numpy in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from hyperopt) (1.23.5)\n",
+ "Requirement already satisfied: scipy in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from hyperopt) (1.10.1)\n",
+ "Requirement already satisfied: six in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from hyperopt) (1.16.0)\n",
+ "Requirement already satisfied: networkx>=2.2 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from hyperopt) (2.8.4)\n",
+ "Requirement already satisfied: future in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from hyperopt) (0.18.3)\n",
+ "Requirement already satisfied: tqdm in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from hyperopt) (4.65.0)\n",
+ "Requirement already satisfied: cloudpickle in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from hyperopt) (2.2.1)\n",
+ "Requirement already satisfied: optuna in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (3.6.1)\n",
+ "Requirement already satisfied: alembic>=1.5.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from optuna) (1.13.1)\n",
+ "Requirement already satisfied: colorlog in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from optuna) (6.8.2)\n",
+ "Requirement already satisfied: numpy in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from optuna) (1.23.5)\n",
+ "Requirement already satisfied: packaging>=20.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from optuna) (23.0)\n",
+ "Requirement already satisfied: sqlalchemy>=1.3.0 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from optuna) (1.4.39)\n",
+ "Requirement already satisfied: tqdm in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from optuna) (4.65.0)\n",
+ "Requirement already satisfied: PyYAML in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from optuna) (6.0)\n",
+ "Requirement already satisfied: Mako in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from alembic>=1.5.0->optuna) (1.3.3)\n",
+ "Requirement already satisfied: typing-extensions>=4 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from alembic>=1.5.0->optuna) (4.9.0)\n",
+ "Requirement already satisfied: greenlet!=0.4.17 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from sqlalchemy>=1.3.0->optuna) (2.0.1)\n",
+ "Requirement already satisfied: MarkupSafe>=0.9.2 in /opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages (from Mako->alembic>=1.5.0->optuna) (2.1.1)\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "major_version, minor_version = torch.cuda.get_device_capability()\n",
+ "# Must install separately since Colab has torch 2.2.1, which breaks packages\n",
+ "!pip install \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
+ "if major_version >= 8:\n",
+ " # Use this for new GPUs like Ampere, Hopper GPUs (RTX 30xx, RTX 40xx, A100, H100, L40)\n",
+ " !pip install --no-deps packaging ninja einops flash-attn xformers trl peft \\\n",
+ " accelerate bitsandbytes\n",
+ "else:\n",
+ " # Use this for older GPUs (V100, Tesla T4, RTX 20xx)\n",
+ " !pip install --no-deps trl peft accelerate bitsandbytes\n",
+ "!pip install datasets\n",
+ "!pip install hyperopt\n",
+ "!pip install optuna \n",
+ "pass"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 4 - Analysis of our infrastructure\n",
+ "In ordering to perform any training it is important to know our system in order to take the full advantage of the system."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "dH4JvbO9oiHE",
+ "outputId": "399bc210-c095-4807-900f-6b4cf2fe133f"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Unable to find python bindings at /usr/local/dcgm/bindings/python3. No data will be captured.\n",
+ "xFormers 0.0.25\n",
+ "memory_efficient_attention.ckF: unavailable\n",
+ "memory_efficient_attention.ckB: unavailable\n",
+ "memory_efficient_attention.ck_decoderF: unavailable\n",
+ "memory_efficient_attention.ck_splitKF: unavailable\n",
+ "memory_efficient_attention.cutlassF: available\n",
+ "memory_efficient_attention.cutlassB: available\n",
+ "memory_efficient_attention.decoderF: available\n",
+ "memory_efficient_attention.flshattF@v2.5.6: available\n",
+ "memory_efficient_attention.flshattB@v2.5.6: available\n",
+ "memory_efficient_attention.smallkF: available\n",
+ "memory_efficient_attention.smallkB: available\n",
+ "memory_efficient_attention.triton_splitKF: unavailable\n",
+ "indexing.scaled_index_addF: unavailable\n",
+ "indexing.scaled_index_addB: unavailable\n",
+ "indexing.index_select: unavailable\n",
+ "sequence_parallel_fused.write_values: available\n",
+ "sequence_parallel_fused.wait_values: available\n",
+ "sequence_parallel_fused.cuda_memset_32b_async: available\n",
+ "sp24.sparse24_sparsify_both_ways: available\n",
+ "sp24.sparse24_apply: available\n",
+ "sp24.sparse24_apply_dense_output: available\n",
+ "sp24._sparse24_gemm: available\n",
+ "sp24._cslt_sparse_mm@0.4.0: available\n",
+ "swiglu.dual_gemm_silu: available\n",
+ "swiglu.gemm_fused_operand_sum: available\n",
+ "swiglu.fused.p.cpp: available\n",
+ "is_triton_available: False\n",
+ "pytorch.version: 2.2.1+cu121\n",
+ "pytorch.cuda: available\n",
+ "gpu.compute_capability: 7.0\n",
+ "gpu.name: Tesla V100-PCIE-16GB\n",
+ "dcgm_profiler: unavailable\n",
+ "build.info: available\n",
+ "build.cuda_version: 1201\n",
+ "build.hip_version: None\n",
+ "build.python_version: 3.10.13\n",
+ "build.torch_version: 2.2.1+cu121\n",
+ "build.env.TORCH_CUDA_ARCH_LIST: 5.0+PTX 6.0 6.1 7.0 7.5 8.0+PTX 9.0\n",
+ "build.env.PYTORCH_ROCM_ARCH: None\n",
+ "build.env.XFORMERS_BUILD_TYPE: Release\n",
+ "build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS: None\n",
+ "build.env.NVCC_FLAGS: None\n",
+ "build.env.XFORMERS_PACKAGE_FROM: wheel-v0.0.25\n",
+ "build.nvcc_version: 12.1.66\n",
+ "source.privacy: open source\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n",
+ "++++++++++++++++++ BUG REPORT INFORMATION ++++++++++++++++++\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n",
+ "++++++++++++++++++++++++++ OTHER +++++++++++++++++++++++++++\n",
+ "CUDA specs: CUDASpecs(highest_compute_capability=(7, 0), cuda_version_string='121', cuda_version_tuple=(12, 1))\n",
+ "PyTorch settings found: CUDA_VERSION=121, Highest Compute Capability: (7, 0).\n",
+ "To manually override the PyTorch CUDA version please see: https://github.com/TimDettmers/bitsandbytes/blob/main/docs/source/nonpytorchcuda.mdx\n",
+ "WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!\n",
+ "If you run into issues with 8-bit matmul, you can try 4-bit quantization:\n",
+ "https://huggingface.co/blog/4bit-transformers-bitsandbytes\n",
+ "The directory listed in your path is found to be non-existent: /usr/local/nvidia/lib\n",
+ "The directory listed in your path is found to be non-existent: //private.runtime.dataplatform.cloud.ibm.com\n",
+ "The directory listed in your path is found to be non-existent: /home/wsuser/jars/*\n",
+ "The directory listed in your path is found to be non-existent: /opt/jdbc/*\n",
+ "The directory listed in your path is found to be non-existent: bluemix/prod\n",
+ "The directory listed in your path is found to be non-existent: //api.dataplatform.cloud.ibm.com\n",
+ "The directory listed in your path is found to be non-existent: //matplotlib_inline.backend_inline\n",
+ "The directory listed in your path is found to be non-existent: --xla_gpu_cuda_data_dir=/opt/conda/envs/Python-RT23.1-CUDA\n",
+ "CUDA SETUP: WARNING! CUDA runtime files not found in any environmental path.\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n",
+ "++++++++++++++++++++++ DEBUG INFO END ++++++++++++++++++++++\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n",
+ "Checking that the library is importable and CUDA is callable...\n",
+ "SUCCESS!\n",
+ "Installation was successful!\n",
+ "Thu May 9 19:59:13 2024 \n",
+ "+---------------------------------------------------------------------------------------+\n",
+ "| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\n",
+ "|-----------------------------------------+----------------------+----------------------+\n",
+ "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
+ "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
+ "| | | MIG M. |\n",
+ "|=========================================+======================+======================|\n",
+ "| 0 Tesla V100-PCIE-16GB Off | 00000000:AF:00.0 Off | 0 |\n",
+ "| N/A 33C P0 40W / 250W | 4MiB / 16384MiB | 0% Default |\n",
+ "| | | N/A |\n",
+ "+-----------------------------------------+----------------------+----------------------+\n",
+ "| 1 Tesla V100-PCIE-16GB Off | 00000000:D8:00.0 Off | 0 |\n",
+ "| N/A 33C P0 26W / 250W | 4MiB / 16384MiB | 0% Default |\n",
+ "| | | N/A |\n",
+ "+-----------------------------------------+----------------------+----------------------+\n",
+ " \n",
+ "+---------------------------------------------------------------------------------------+\n",
+ "| Processes: |\n",
+ "| GPU GI CI PID Type Process name GPU Memory |\n",
+ "| ID ID Usage |\n",
+ "|=======================================================================================|\n",
+ "| No running processes found |\n",
+ "+---------------------------------------------------------------------------------------+\n"
+ ]
+ }
+ ],
+ "source": [
+ "!python -m xformers.info\n",
+ "!python -m bitsandbytes\n",
+ "!nvidia-smi"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 5 Login to Hugging Face"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
+ "Token is valid (permission: write).\n",
+ "Your token has been saved to /home/wsuser/.cache/huggingface/token\n",
+ "Login successful\n"
+ ]
+ }
+ ],
+ "source": [
+ "token=\"hf_\"\n",
+ "from huggingface_hub import login, logout\n",
+ "login(token) # non-blocking login"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 5 Simple Fine Tunning Method\n",
+ "\n",
+ "First let us show the simplest method that is given by SFTTrainer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "25db645610ac40e6a8a647896dec0f16",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Resolving data files: 0%| | 0/18 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a8d6ac36539a4c12a5dc6eafbc58d9ca",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Resolving data files: 0%| | 0/18 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "53fa0f7afe9849c99fa1761aafd363c1",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading dataset shards: 0%| | 0/18 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "da22a855c9524144870d16abcf850047",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "tokenizer_config.json: 0%| | 0.00/51.0k [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "6770b2824fa64dc1a3dd6279683e6915",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "tokenizer.json: 0%| | 0.00/9.09M [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "cd128b3e68ac4bf6a87d812e1c0c248e",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "special_tokens_map.json: 0%| | 0.00/73.0 [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0fb4fe0fc5384b3182762db3f8c42b9e",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "config.json: 0%| | 0.00/654 [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "7f9cfeabcf424ccc98a2b9d69d770a0a",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model.safetensors.index.json: 0%| | 0.00/23.9k [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "8244c9b77e7f4afe8524963a8f7b02e2",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading shards: 0%| | 0/4 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "690662c1a9244cdd9ea9f1162ce3cd30",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model-00001-of-00004.safetensors: 0%| | 0.00/4.98G [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "3813d56b85b74450a93e9ade68a807b5",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model-00002-of-00004.safetensors: 0%| | 0.00/5.00G [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "ebb06c38dd79488ab8eb273359406577",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model-00003-of-00004.safetensors: 0%| | 0.00/4.92G [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "5251b05267db420d8f14ce05b7d52b81",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model-00004-of-00004.safetensors: 0%| | 0.00/1.17G [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "947b7ad7d14f4201902e1a292ec7f513",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/4 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "1ddcf0ae14934fd79d8926aee0972abe",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "generation_config.json: 0%| | 0.00/187 [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "7423975680ef41ff9ac362c72b2c0907",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Map: 0%| | 0/100 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "max_steps is given, it will override any value given in num_train_epochs\n",
+ "/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [1/1 00:04, Epoch 0/1]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Step | \n",
+ " Training Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 2.346700 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "TrainOutput(global_step=1, training_loss=2.346719980239868, metrics={'train_runtime': 8.2949, 'train_samples_per_second': 0.482, 'train_steps_per_second': 0.121, 'total_flos': 74593973698560.0, 'train_loss': 2.346719980239868, 'epoch': 0.04})"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from datasets import load_dataset\n",
+ "import torch\n",
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer\n",
+ "from peft import LoraConfig\n",
+ "from trl import SFTTrainer\n",
+ "from transformers import TrainingArguments\n",
+ "# Load the dataset\n",
+ "dataset_name = \"ruslanmv/ai-medical-dataset\"\n",
+ "dataset = load_dataset(dataset_name, split=\"train\")\n",
+ "# Select the first 1000 rows of the dataset\n",
+ "dataset = dataset.select(range(100))\n",
+ "# Device map\n",
+ "device_map = 'auto' # for PP and running with `python test_sft.py`\n",
+ "# Load the model + tokenizer\n",
+ "model_name = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
+ "tokenizer.pad_token = tokenizer.eos_token\n",
+ "bnb_config = BitsAndBytesConfig(\n",
+ " load_in_4bit=True,\n",
+ " bnb_4bit_quant_type=\"nf4\",\n",
+ " bnb_4bit_compute_dtype=torch.float16,\n",
+ ")\n",
+ "model = AutoModelForCausalLM.from_pretrained(\n",
+ " model_name,\n",
+ " quantization_config=bnb_config,\n",
+ " trust_remote_code=True,\n",
+ " use_cache=False,\n",
+ " device_map=device_map\n",
+ ")\n",
+ "# PEFT config\n",
+ "lora_alpha = 16\n",
+ "lora_dropout = 0.1\n",
+ "lora_r = 32 # 64\n",
+ "peft_config = LoraConfig(\n",
+ " lora_alpha=lora_alpha,\n",
+ " lora_dropout=lora_dropout,\n",
+ " r=lora_r,\n",
+ " bias=\"none\",\n",
+ " task_type=\"CAUSAL_LM\",\n",
+ " target_modules=[\"k_proj\", \"q_proj\", \"v_proj\", \"up_proj\", \"down_proj\", \"gate_proj\"],\n",
+ " modules_to_save=[\"embed_tokens\", \"input_layernorm\", \"post_attention_layernorm\", \"norm\"],\n",
+ ")\n",
+ "# Args\n",
+ "max_seq_length = 512\n",
+ "output_dir = \"./results\"\n",
+ "per_device_train_batch_size = 2 # reduced batch size to avoid OOM\n",
+ "gradient_accumulation_steps = 2\n",
+ "optim = \"adamw_torch\"\n",
+ "save_steps = 10\n",
+ "logging_steps = 1\n",
+ "learning_rate = 2e-4\n",
+ "max_grad_norm = 0.3\n",
+ "max_steps = 1 \n",
+ "warmup_ratio = 0.1\n",
+ "lr_scheduler_type = \"cosine\"\n",
+ "training_arguments = TrainingArguments(\n",
+ " output_dir=output_dir,\n",
+ " per_device_train_batch_size=per_device_train_batch_size,\n",
+ " gradient_accumulation_steps=gradient_accumulation_steps,\n",
+ " optim=optim,\n",
+ " save_steps=save_steps,\n",
+ " logging_steps=logging_steps,\n",
+ " learning_rate=learning_rate,\n",
+ " fp16=True,\n",
+ " max_grad_norm=max_grad_norm,\n",
+ " max_steps=max_steps,\n",
+ " warmup_ratio=warmup_ratio,\n",
+ " group_by_length=True,\n",
+ " lr_scheduler_type=lr_scheduler_type,\n",
+ " gradient_checkpointing=True, # gradient checkpointing\n",
+ " #report_to=\"wandb\",\n",
+ ")\n",
+ "# Trainer\n",
+ "trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " train_dataset=dataset,\n",
+ " peft_config=peft_config,\n",
+ " dataset_text_field=\"context\",\n",
+ " max_seq_length=max_seq_length,\n",
+ " tokenizer=tokenizer,\n",
+ " args=training_arguments,\n",
+ ")\n",
+ "\n",
+ "# Train :)\n",
+ "trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
+ "Token is valid (permission: write).\n",
+ "Your token has been saved to /home/wsuser/.cache/huggingface/token\n",
+ "Login successful\n"
+ ]
+ }
+ ],
+ "source": [
+ "token=\"\"\n",
+ "from huggingface_hub import login, logout\n",
+ "login(token) # non-blocking login"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Multiple GPUS"
+ ]
+ },
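+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The following cells record several attempts at multi-GPU training launched from inside the notebook. For comparison, the usual pattern (shown here only as a sketch, with the same small `facebook/bart-base` model and GLUE/MNLI data used in the cells below; the file name `train_sft.py` is just an illustrative placeholder) is to put the trainer code in a standalone script and start one process per GPU with `torchrun` or `accelerate launch`; the Trainer then reads the distributed environment itself, with no manual `init_process_group` or DDP wrapping.\n",
+ "\n",
+ "```python\n",
+ "# train_sft.py -- illustrative sketch; launch with: torchrun --nproc_per_node=2 train_sft.py\n",
+ "from datasets import load_dataset\n",
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments\n",
+ "from trl import SFTTrainer\n",
+ "\n",
+ "model_name = \"facebook/bart-base\"\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
+ "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
+ "\n",
+ "dataset = load_dataset(\"glue\", \"mnli\", split=\"train\")\n",
+ "\n",
+ "training_args = TrainingArguments(\n",
+ "    output_dir=\"./results\",\n",
+ "    per_device_train_batch_size=2,\n",
+ "    num_train_epochs=3,\n",
+ ")\n",
+ "\n",
+ "# Each torchrun worker builds its own trainer; the Trainer detects LOCAL_RANK /\n",
+ "# WORLD_SIZE from the environment and applies DDP to the model internally.\n",
+ "trainer = SFTTrainer(\n",
+ "    model=model,\n",
+ "    args=training_args,\n",
+ "    train_dataset=dataset,\n",
+ "    dataset_text_field=\"premise\",\n",
+ "    tokenizer=tokenizer,\n",
+ ")\n",
+ "trainer.train()\n",
+ "```"
+ ]
+ },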
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
+ " warnings.warn(\n",
+ "Some weights of BartForCausalLM were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['decoder.embed_tokens.weight', 'lm_head.weight']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
+ "from datasets import load_dataset\n",
+ "from trl import SFTTrainer\n",
+ "from transformers import TrainingArguments\n",
+ "import os\n",
+ "import socket\n",
+ "\n",
+ "# Distributed training setup (assuming all GPUs are available on a single machine)\n",
+ "def init_distributed(rank, world_size):\n",
+ " \"\"\"Initializes distributed training using `nccl` backend.\"\"\"\n",
+ " if rank == 0:\n",
+ " os.environ[\"MASTER_ADDR\"] = socket.gethostname() # Set MASTER_ADDR using rank 0's hostname\n",
+ " else:\n",
+ " # Wait a bit to ensure MASTER_ADDR is set before other ranks try to use it\n",
+ " import time\n",
+ " time.sleep(5)\n",
+ " os.environ[\"MASTER_PORT\"] = \"12345\" # Set MASTER_PORT environment variable\n",
+ " os.environ[\"RANK\"] = str(rank) # Set RANK environment variable\n",
+ " os.environ[\"WORLD_SIZE\"] = str(world_size) # Set WORLD_SIZE environment variable\n",
+ " torch.distributed.init_process_group(backend='nccl', init_method='env://')\n",
+ "\n",
+ "# Cleanup after training\n",
+ "def cleanup_distributed():\n",
+ " if torch.distributed.is_initialized():\n",
+ " torch.distributed.destroy_process_group()\n",
+ "\n",
+ "# Model and tokenizer selection\n",
+ "model_name = \"facebook/bart-base\" # Replace with your desired model\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
+ "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
+ "\n",
+ "# Dataset loading (replace with your dataset and field names)\n",
+ "dataset = load_dataset(\"glue\", \"mnli\", split=\"train\")\n",
+ "text_field = \"premise\" # Assuming premise is the field containing text for prediction\n",
+ "\n",
+ "# Training arguments (adjust hyperparameters as needed)\n",
+ "training_args = TrainingArguments(\n",
+ " output_dir=\"./results\",\n",
+ " per_device_train_batch_size=2, # Adjust based on GPU memory (might need to adjust)\n",
+ " save_steps=500,\n",
+ " save_total_limit=2,\n",
+ " num_train_epochs=3, # Adjust training time as needed\n",
+ ")\n",
+ "\n",
+ "world_size = torch.cuda.device_count()\n",
+ "if world_size > 1:\n",
+ " # Initialize distributed training\n",
+ " init_distributed(rank=0, world_size=world_size) # Rank is assumed to be 0 here\n",
+ "\n",
+ " # Wrap model in DDP for distributed training\n",
+ " model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[torch.cuda.current_device()])\n",
+ "\n",
+ " # Create SFTTrainer with distributed settings\n",
+ " trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " args=training_args, # Pass training_args as 'args' instead of 'training_args'\n",
+ " train_dataset=dataset,\n",
+ " dataset_text_field=text_field,\n",
+ " compute_metrics=None, # You can define your custom metrics here\n",
+ " )\n",
+ " print(\"Trainer For distributed training loaded\")\n",
+ "else:\n",
+ " # For single-GPU training\n",
+ " trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " args=training_args, # Pass training_args as 'args' instead of 'training_args'\n",
+ " train_dataset=dataset,\n",
+ " dataset_text_field=text_field,\n",
+ " compute_metrics=None, # You can define your custom metrics here\n",
+ " )\n",
+ " print(\"Trainer For single-GPU loaded\")\n",
+ "\n",
+ "# Start training\n",
+ "trainer.train()\n",
+ "\n",
+ "# Cleanup after training\n",
+ "cleanup_distributed()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
+ " warnings.warn(\n",
+ "Some weights of BartForCausalLM were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['decoder.embed_tokens.weight', 'lm_head.weight']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
+ "/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
+ " warnings.warn(\n",
+ "/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:246: UserWarning: You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to 1024\n",
+ " warnings.warn(\n",
+ "Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Trainer For single-GPU loaded\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 329/294528 00:36 < 9:12:58, 8.87 it/s, Epoch 0.00/3]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Step | \n",
+ " Training Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "ename": "KeyboardInterrupt",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[1], line 66\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTrainer For single-GPU loaded\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 65\u001b[0m \u001b[38;5;66;03m# Start training\u001b[39;00m\n\u001b[0;32m---> 66\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;66;03m# Cleanup after training\u001b[39;00m\n\u001b[1;32m 69\u001b[0m cleanup_distributed()\n",
+ "File \u001b[0;32m/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:361\u001b[0m, in \u001b[0;36mSFTTrainer.train\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 358\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mneftune_noise_alpha \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_trainer_supports_neftune:\n\u001b[1;32m 359\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_trl_activate_neftune(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel)\n\u001b[0;32m--> 361\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 363\u001b[0m \u001b[38;5;66;03m# After training we make sure to retrieve back the original forward pass method\u001b[39;00m\n\u001b[1;32m 364\u001b[0m \u001b[38;5;66;03m# for the embedding layer by removing the forward post hook.\u001b[39;00m\n\u001b[1;32m 365\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mneftune_noise_alpha \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_trainer_supports_neftune:\n",
+ "File \u001b[0;32m/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/transformers/trainer.py:1859\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1857\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m 1858\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1859\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1860\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1861\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1862\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1863\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1864\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/transformers/trainer.py:2203\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_begin(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[1;32m 2202\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39maccumulate(model):\n\u001b[0;32m-> 2203\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2205\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 2206\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 2207\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[1;32m 2208\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 2209\u001b[0m ):\n\u001b[1;32m 2210\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 2211\u001b[0m tr_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
+ "File \u001b[0;32m/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/transformers/trainer.py:3138\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m 3135\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss_mb\u001b[38;5;241m.\u001b[39mreduce_mean()\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m 3137\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompute_loss_context_manager():\n\u001b[0;32m-> 3138\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3140\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mn_gpu \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 3141\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mmean() \u001b[38;5;66;03m# mean() to average on multi-gpu parallel training\u001b[39;00m\n",
+ "File \u001b[0;32m/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/transformers/trainer.py:3161\u001b[0m, in \u001b[0;36mTrainer.compute_loss\u001b[0;34m(self, model, inputs, return_outputs)\u001b[0m\n\u001b[1;32m 3159\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3160\u001b[0m labels \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m-> 3161\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3162\u001b[0m \u001b[38;5;66;03m# Save past state if it exists\u001b[39;00m\n\u001b[1;32m 3163\u001b[0m \u001b[38;5;66;03m# TODO: this needs to be fixed and made cleaner later.\u001b[39;00m\n\u001b[1;32m 3164\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mpast_index \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m:\n",
+ "File \u001b[0;32m/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
+ "File \u001b[0;32m/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:185\u001b[0m, in \u001b[0;36mDataParallel.forward\u001b[0;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[1;32m 183\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule(\u001b[38;5;241m*\u001b[39minputs[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodule_kwargs[\u001b[38;5;241m0\u001b[39m])\n\u001b[1;32m 184\u001b[0m replicas \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreplicate(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice_ids[:\u001b[38;5;28mlen\u001b[39m(inputs)])\n\u001b[0;32m--> 185\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparallel_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreplicas\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodule_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgather(outputs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_device)\n",
+ "File \u001b[0;32m/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:200\u001b[0m, in \u001b[0;36mDataParallel.parallel_apply\u001b[0;34m(self, replicas, inputs, kwargs)\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mparallel_apply\u001b[39m(\u001b[38;5;28mself\u001b[39m, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m List[Any]:\n\u001b[0;32m--> 200\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mparallel_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreplicas\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice_ids\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mreplicas\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py:100\u001b[0m, in \u001b[0;36mparallel_apply\u001b[0;34m(modules, inputs, kwargs_tup, devices)\u001b[0m\n\u001b[1;32m 98\u001b[0m thread\u001b[38;5;241m.\u001b[39mstart()\n\u001b[1;32m 99\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m thread \u001b[38;5;129;01min\u001b[39;00m threads:\n\u001b[0;32m--> 100\u001b[0m \u001b[43mthread\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 101\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 102\u001b[0m _worker(\u001b[38;5;241m0\u001b[39m, modules[\u001b[38;5;241m0\u001b[39m], inputs[\u001b[38;5;241m0\u001b[39m], kwargs_tup[\u001b[38;5;241m0\u001b[39m], devices[\u001b[38;5;241m0\u001b[39m], streams[\u001b[38;5;241m0\u001b[39m])\n",
+ "File \u001b[0;32m/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/threading.py:1096\u001b[0m, in \u001b[0;36mThread.join\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 1093\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcannot join current thread\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1095\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 1096\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_wait_for_tstate_lock\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1097\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1098\u001b[0m \u001b[38;5;66;03m# the behavior of a negative timeout isn't documented, but\u001b[39;00m\n\u001b[1;32m 1099\u001b[0m \u001b[38;5;66;03m# historically .join(timeout=x) for x<0 has acted as if timeout=0\u001b[39;00m\n\u001b[1;32m 1100\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wait_for_tstate_lock(timeout\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mmax\u001b[39m(timeout, \u001b[38;5;241m0\u001b[39m))\n",
+ "File \u001b[0;32m/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/threading.py:1116\u001b[0m, in \u001b[0;36mThread._wait_for_tstate_lock\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m 1113\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 1115\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1116\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mlock\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43macquire\u001b[49m\u001b[43m(\u001b[49m\u001b[43mblock\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 1117\u001b[0m lock\u001b[38;5;241m.\u001b[39mrelease()\n\u001b[1;32m 1118\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_stop()\n",
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
+ "from datasets import load_dataset\n",
+ "from trl import SFTTrainer\n",
+ "from transformers import TrainingArguments\n",
+ "import os\n",
+ "\n",
+ "# Distributed training setup (assuming all GPUs are available on a single machine)\n",
+ "def init_distributed():\n",
+ " # Replace with actual hostname or IP if using multiple machines\n",
+ " os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
+ " os.environ[\"MASTER_PORT\"] = \"12345\"\n",
+ " torch.distributed.init_process_group(backend='nccl', world_size=torch.cuda.device_count(), rank=rank)\n",
+ "\n",
+ "def cleanup_distributed():\n",
+ " torch.distributed.destroy_process_group()\n",
+ "\n",
+ "# Model and tokenizer selection\n",
+ "model_name = \"facebook/bart-base\" # Replace with your desired model\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
+ "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
+ "\n",
+ "# Dataset loading (replace with your dataset and field names)\n",
+ "dataset = load_dataset(\"glue\", \"mnli\", split=\"train\")\n",
+ "text_field = \"premise\" # Assuming premise is the field containing text for prediction\n",
+ "\n",
+ "# Training arguments (adjust hyperparameters as needed)\n",
+ "training_args = TrainingArguments(\n",
+ " output_dir=\"./results\",\n",
+ " per_device_train_batch_size=2, # Adjust based on GPU memory\n",
+ " save_steps=500,\n",
+ " save_total_limit=2,\n",
+ " num_train_epochs=3, # Adjust training time as needed\n",
+ ")\n",
+ "\n",
+ "# Distributed training setup with SFTTrainer\n",
+ "if torch.distributed.is_initialized():\n",
+ " rank = torch.distributed.get_rank()\n",
+ " world_size = torch.distributed.get_world_size()\n",
+ " # Wrap model in DDP for distributed training\n",
+ " model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])\n",
+ "\n",
+ " # Create SFTTrainer with distributed settings\n",
+ " trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " args=training_args, # Pass training_args as 'args' instead of 'training_args'\n",
+ " train_dataset=dataset,\n",
+ " dataset_text_field=text_field,\n",
+ " compute_metrics=None, # You can define your custom metrics here\n",
+ " world_size=world_size,\n",
+ " rank=rank,\n",
+ " )\n",
+ " print(f\"Trainer For distributed training loaded on rank {rank}\")\n",
+ "else:\n",
+ " # For single-GPU training\n",
+ " trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " args=training_args, # Pass training_args as 'args' instead of 'training_args'\n",
+ " train_dataset=dataset,\n",
+ " dataset_text_field=text_field,\n",
+ " compute_metrics=None, # You can define your custom metrics here\n",
+ " )\n",
+ " print(\"Trainer For single-GPU loaded\")\n",
+ "\n",
+ "# Start training\n",
+ "trainer.train()\n",
+ "\n",
+ "# Cleanup after training\n",
+ "cleanup_distributed()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Traceback (most recent call last):\n",
+ " File \"\", line 1, in \n",
+ " File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/multiprocessing/spawn.py\", line 116, in spawn_main\n",
+ " exitcode = _main(fd, parent_sentinel)\n",
+ " File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/multiprocessing/spawn.py\", line 126, in _main\n",
+ " self = reduction.pickle.load(from_parent)\n",
+ "AttributeError: Can't get attribute 'main_worker' on \n",
+ "Traceback (most recent call last):\n",
+ " File \"\", line 1, in \n",
+ " File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/multiprocessing/spawn.py\", line 116, in spawn_main\n",
+ " exitcode = _main(fd, parent_sentinel)\n",
+ " File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/multiprocessing/spawn.py\", line 126, in _main\n",
+ " self = reduction.pickle.load(from_parent)\n",
+ "AttributeError: Can't get attribute 'main_worker' on \n"
+ ]
+ }
+ ],
+ "source": [
+ "import os\n",
+ "import torch\n",
+ "import torch.multiprocessing as mp\n",
+ "from datasets import load_dataset\n",
+ "import torch\n",
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer\n",
+ "from peft import LoraConfig\n",
+ "from trl import SFTTrainer\n",
+ "from transformers import TrainingArguments\n",
+ "from torch.nn.parallel import DistributedDataParallel as DDP\n",
+ "\n",
+ "\n",
+ "# Distributed training setup\n",
+ "def init_distributed():\n",
+ " os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
+ " os.environ[\"MASTER_PORT\"] = \"12345\"\n",
+ " torch.distributed.init_process_group(backend='nccl', world_size=torch.cuda.device_count(), rank=rank)\n",
+ "\n",
+ "def cleanup_distributed():\n",
+ " torch.distributed.destroy_process_group()\n",
+ "\n",
+ "def main_worker(rank, world_size):\n",
+ " init_distributed()\n",
+ "\n",
+ " # Your model training and fine-tuning code goes here\n",
+ " # Load the dataset\n",
+ " dataset_name = \"ruslanmv/ai-medical-dataset\"\n",
+ " dataset = load_dataset(dataset_name, split=\"train\")\n",
+ " # Select the first 1M rows of the dataset\n",
+ " dataset = dataset.select(range(100))\n",
+ "\n",
+ " # Load the model + tokenizer\n",
+ " model_name = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
+ " tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
+ " tokenizer.pad_token = tokenizer.eos_token\n",
+ " bnb_config = BitsAndBytesConfig(\n",
+ " load_in_4bit=True,\n",
+ " bnb_4bit_quant_type=\"nf4\",\n",
+ " bnb_4bit_compute_dtype=torch.float16,\n",
+ " )\n",
+ " model = AutoModelForCausalLM.from_pretrained(\n",
+ " model_name,\n",
+ " quantization_config=bnb_config,\n",
+ " trust_remote_code=True,\n",
+ " use_cache=False,\n",
+ " )\n",
+ "\n",
+ " # Check for available GPUs\n",
+ " device = torch.device(f\"cuda:{rank}\" if torch.cuda.is_available() else \"cpu\")\n",
+ "\n",
+ " # PEFT config\n",
+ " lora_alpha = 1\n",
+ " lora_dropout = 0.1\n",
+ " lora_r = 32 # 64\n",
+ " peft_config = LoraConfig(\n",
+ " lora_alpha=lora_alpha,\n",
+ " lora_dropout=lora_dropout,\n",
+ " task_type=\"CAUSAL_LM\",\n",
+ " target_modules=[\"k_proj\", \"q_proj\", \"v_proj\", \"up_proj\", \"down_proj\", \"gate_proj\"],\n",
+ " modules_to_save=[\"embed_tokens\", \"input_layernorm\", \"post_attention_layernorm\", \"norm\"],\n",
+ " )\n",
+ "\n",
+ " # Args\n",
+ " max_seq_length = 512\n",
+ " output_dir = \"./results\"\n",
+ " per_device_train_batch_size = 2 # reduced batch size to avoid OOM\n",
+ " gradient_accumulation_steps = 2\n",
+ " optim = \"adamw_torch\"\n",
+ " save_steps = 10\n",
+ " logging_steps = 1\n",
+ " learning_rate = 2e-4\n",
+ " max_grad_norm = 0.3\n",
+ " max_steps = 1 # 300 Approx the size of guanaco at bs 8, ga 2, 2 GPUs.\n",
+ " warmup_ratio = 0.1\n",
+ " lr_scheduler_type = \"cosine\"\n",
+ " training_arguments = TrainingArguments(\n",
+ " output_dir=output_dir,\n",
+ " per_device_train_batch_size=per_device_train_batch_size,\n",
+ " gradient_accumulation_steps=gradient_accumulation_steps,\n",
+ " optim=optim,\n",
+ " save_steps=save_steps,\n",
+ " logging_steps=logging_steps,\n",
+ " learning_rate=learning_rate,\n",
+ " fp16=True,\n",
+ " max_grad_norm=max_grad_norm,\n",
+ " max_steps=max_steps,\n",
+ " warmup_ratio=warmup_ratio,\n",
+ " group_by_length=True,\n",
+ " lr_scheduler_type=lr_scheduler_type,\n",
+ " gradient_checkpointing=True, # gradient checkpointing\n",
+ " #report_to=\"wandb\",\n",
+ " )\n",
+ "\n",
+ " # Trainer\n",
+ " trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " train_dataset=dataset,\n",
+ " peft_config=peft_config,\n",
+ " dataset_text_field=\"context\",\n",
+ " max_seq_length=max_seq_length,\n",
+ " tokenizer=tokenizer,\n",
+ " args=training_arguments,\n",
+ " )\n",
+ "\n",
+ " # Train :)\n",
+ " trainer.train()\n",
+ " cleanup_distributed()\n",
+ "\n",
+ "\n",
+ "if __name__ == \"__main__\":\n",
+ " world_size = torch.cuda.device_count()\n",
+ " mp.set_start_method('spawn') # Add this line to fix the error\n",
+ " processes = []\n",
+ " for rank in range(world_size):\n",
+ " p = mp.Process(target=main_worker, args=(rank, world_size))\n",
+ " p.start()\n",
+ " processes.append(p)\n",
+ " for p in processes:\n",
+ " p.join()\n"
+ ]
+ },
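+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The `AttributeError: Can't get attribute 'main_worker'` above comes from the `spawn` start method: the child processes re-import `__main__` and cannot find functions that were defined interactively in the notebook. A sketch of the usual workaround (assuming the worker lives in a separate importable module, here hypothetically called `workers.py`) is to spawn from a plain script instead:\n",
+ "\n",
+ "```python\n",
+ "# launch.py -- illustrative sketch only; workers.py is a hypothetical module\n",
+ "import torch\n",
+ "import torch.multiprocessing as mp\n",
+ "from workers import main_worker  # main_worker(rank, world_size) defined in an importable file\n",
+ "\n",
+ "if __name__ == \"__main__\":\n",
+ "    world_size = torch.cuda.device_count()\n",
+ "    # mp.spawn passes the rank as the first argument to main_worker\n",
+ "    mp.spawn(main_worker, args=(world_size,), nprocs=world_size)\n",
+ "```"
+ ]
+ },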
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "Supervised Finetuning Trainer (SFT Trainer)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def finetune():\n",
+ " from datasets import load_dataset\n",
+ " import torch\n",
+ " from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer\n",
+ " from peft import LoraConfig\n",
+ " from trl import SFTTrainer\n",
+ " from transformers import TrainingArguments\n",
+ " from torch.nn.parallel import DistributedDataParallel as DDP\n",
+ " # Load the dataset\n",
+ " dataset_name = \"ruslanmv/ai-medical-dataset\"\n",
+ " dataset = load_dataset(dataset_name, split=\"train\")\n",
+ " # Select the first 1M rows of the dataset\n",
+ " dataset = dataset.select(range(100))\n",
+ " # Load the model + tokenizer\n",
+ " model_name = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
+ " tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
+ " tokenizer.pad_token = tokenizer.eos_token\n",
+ " bnb_config = BitsAndBytesConfig(\n",
+ " load_in_4bit=True,\n",
+ " bnb_4bit_quant_type=\"nf4\",\n",
+ " bnb_4bit_compute_dtype=torch.float16,\n",
+ " )\n",
+ " model = AutoModelForCausalLM.from_pretrained(\n",
+ " model_name,\n",
+ " quantization_config=bnb_config,\n",
+ " trust_remote_code=True,\n",
+ " use_cache=False,\n",
+ " )\n",
+ " # Check for available GPUs\n",
+ " if torch.cuda.device_count() > 1:\n",
+ " print(\"Multiple GPUs detected, enabling DataParallel...\")\n",
+ " model = DDP(model) # Wrap the model with DDP\n",
+ " else:\n",
+ " print(\"Using single GPU...\")\n",
+ " # PEFT config\n",
+ " lora_alpha = 16\n",
+ " lora_dropout = 0.1\n",
+ " lora_r = 32 # 64\n",
+ " peft_config = LoraConfig(\n",
+ " lora_alpha=lora_alpha,\n",
+ " lora_dropout=lora_dropout,\n",
+ " r=lora_r,\n",
+ " bias=\"none\",\n",
+ " task_type=\"CAUSAL_LM\",\n",
+ " target_modules=[\"k_proj\", \"q_proj\", \"v_proj\", \"up_proj\", \"down_proj\", \"gate_proj\"],\n",
+ " modules_to_save=[\"embed_tokens\", \"input_layernorm\", \"post_attention_layernorm\", \"norm\"],\n",
+ " )\n",
+ " # Args\n",
+ " max_seq_length = 512\n",
+ " output_dir = \"./results\"\n",
+ " per_device_train_batch_size = 2 # reduced batch size to avoid OOM\n",
+ " gradient_accumulation_steps = 2\n",
+ " optim = \"adamw_torch\"\n",
+ " save_steps = 10\n",
+ " logging_steps = 1\n",
+ " learning_rate = 2e-4\n",
+ " max_grad_norm = 0.3\n",
+ " max_steps = 1 # 300 Approx the size of guanaco at bs 8, ga 2, 2 GPUs.\n",
+ " warmup_ratio = 0.1\n",
+ " lr_scheduler_type = \"cosine\"\n",
+ "\n",
+ " training_arguments = TrainingArguments(\n",
+ " output_dir=output_dir,\n",
+ " per_device_train_batch_size=per_device_train_batch_size,\n",
+ " gradient_accumulation_steps=gradient_accumulation_steps,\n",
+ " optim=optim,\n",
+ " save_steps=save_steps,\n",
+ " logging_steps=logging_steps,\n",
+ " learning_rate=learning_rate,\n",
+ " fp16=True,\n",
+ " max_grad_norm=max_grad_norm,\n",
+ " max_steps=max_steps,\n",
+ " warmup_ratio=warmup_ratio,\n",
+ " group_by_length=True,\n",
+ " lr_scheduler_type=lr_scheduler_type,\n",
+ " gradient_checkpointing=True, # gradient checkpointing\n",
+ " #report_to=\"wandb\",\n",
+ " )\n",
+ " # Trainer\n",
+ " trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " train_dataset=dataset,\n",
+ " peft_config=peft_config,\n",
+ " dataset_text_field=\"context\",\n",
+ " max_seq_length=max_seq_length,\n",
+ " tokenizer=tokenizer,\n",
+ " args=training_arguments,\n",
+ " )\n",
+ " # Train :)\n",
+ " trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Initializing distributed process group...\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d32deaece84b4e2382865810c9c3f1f4",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Resolving data files: 0%| | 0/18 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "6b4dd257b7a447ac98d8d3c926b41a85",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Resolving data files: 0%| | 0/18 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "48dea94b34a84052a77c073ede9289e4",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading dataset shards: 0%| | 0/18 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
+ " warnings.warn(\n",
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
+ "`low_cpu_mem_usage` was None, now set to True since model is quantized.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "6b29e51f80e84dcd9192be7755f12d79",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/4 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
+ " warnings.warn(\n",
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
+ "`low_cpu_mem_usage` was None, now set to True since model is quantized.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "449bb46e71164a8687ecbdb75000da92",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/4 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Multiple GPUs detected, enabling DataParallel...\n",
+ "Multiple GPUs detected, enabling DataParallel...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. In order to use Torch DDP, launch your script with `python -m torch.distributed.launch\n",
+ "torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. In order to use Torch DDP, launch your script with `python -m torch.distributed.launch\n",
+ "Process Process-3:\n",
+ "Traceback (most recent call last):\n",
+ " File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/multiprocessing/process.py\", line 314, in _bootstrap\n",
+ " self.run()\n",
+ " File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/multiprocessing/process.py\", line 108, in run\n",
+ " self._target(*self._args, **self._kwargs)\n",
+ " File \"/tmp/wsuser/ipykernel_4766/3134301475.py\", line 21, in main_worker\n",
+ " finetune() # Move model to assigned GPU\n",
+ " File \"/tmp/wsuser/ipykernel_4766/2145007916.py\", line 80, in finetune\n",
+ " trainer = SFTTrainer(\n",
+ " File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/trl/trainer/sft_trainer.py\", line 226, in __init__\n",
+ " model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)\n",
+ " File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1688, in __getattr__\n",
+ " raise AttributeError(f\"'{type(self).__name__}' object has no attribute '{name}'\")\n",
+ "AttributeError: 'DistributedDataParallel' object has no attribute 'get_input_embeddings'\n",
+ "Process Process-4:\n",
+ "Traceback (most recent call last):\n",
+ " File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/multiprocessing/process.py\", line 314, in _bootstrap\n",
+ " self.run()\n",
+ " File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/multiprocessing/process.py\", line 108, in run\n",
+ " self._target(*self._args, **self._kwargs)\n",
+ " File \"/tmp/wsuser/ipykernel_4766/3134301475.py\", line 21, in main_worker\n",
+ " finetune() # Move model to assigned GPU\n",
+ " File \"/tmp/wsuser/ipykernel_4766/2145007916.py\", line 80, in finetune\n",
+ " trainer = SFTTrainer(\n",
+ " File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/trl/trainer/sft_trainer.py\", line 226, in __init__\n",
+ " model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)\n",
+ " File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1688, in __getattr__\n",
+ " raise AttributeError(f\"'{type(self).__name__}' object has no attribute '{name}'\")\n",
+ "AttributeError: 'DistributedDataParallel' object has no attribute 'get_input_embeddings'\n"
+ ]
+ }
+ ],
+ "source": [
+ "import os\n",
+ "import torch\n",
+ "import torch.multiprocessing as mp\n",
+ "\n",
+ "def init_distributed(rank, world_size, local_rank=0): # Add local_rank argument\n",
+ " os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
+ " os.environ[\"MASTER_PORT\"] = \"12345\" # Adjust port if needed\n",
+ " if rank == 0:\n",
+ " print(\"Initializing distributed process group...\")\n",
+ " torch.distributed.init_process_group(backend='nccl', world_size=world_size, rank=rank)\n",
+ " torch.cuda.set_device(local_rank) # Set unique GPU device for each process\n",
+ "\n",
+ "def cleanup_distributed():\n",
+ " torch.distributed.destroy_process_group()\n",
+ "\n",
+ "def main_worker(rank, world_size):\n",
+ " local_rank = rank % torch.cuda.device_count() # Assign unique local rank\n",
+ " init_distributed(rank, world_size, local_rank)\n",
+ " # Your model training and fine-tuning code goes here with model on local_rank GPU\n",
+ " finetune() # Move model to assigned GPU\n",
+ " cleanup_distributed()\n",
+ "if __name__ == \"__main__\":\n",
+ " world_size = torch.cuda.device_count()\n",
+ "\n",
+ " # Workaround for Jupyter Notebook and interactive environments\n",
+ " processes = []\n",
+ " for rank in range(world_size):\n",
+ " p = mp.Process(target=main_worker, args=(rank, world_size))\n",
+ " p.start()\n",
+ " processes.append(p)\n",
+ "\n",
+ " for p in processes:\n",
+ " p.join()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# How to FineTune Llama 3 with Unsloth\n",
+ "Hello everyone, today we are going to show how we can Fine Tune Llama 2 with a Usloth package."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 5 - Loading packages\n",
+ "Once we have installed all the packages we load the packages."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2.2.1+cu121\n",
+ "12.1\n",
+ "The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
+ "Token is valid (permission: write).\n",
+ "Your token has been saved to /home/wsuser/.cache/huggingface/token\n",
+ "Login successful\n"
+ ]
+ }
+ ],
+ "source": [
+ "import json\n",
+ "import torch\n",
+ "from datasets import load_dataset\n",
+ "from huggingface_hub import notebook_login\n",
+ "from transformers import TrainingArguments\n",
+ "from trl import SFTTrainer\n",
+ "from unsloth import FastLanguageModel\n",
+ "print(torch.__version__)\n",
+ "print(torch.version.cuda)\n",
+ "token=\"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\"\n",
+ "from huggingface_hub import login, logout\n",
+ "login(token) # non-blocking login"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 6 - Setup configuration"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "**Model Configuration**\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_config={ \"model_config\": {\n",
+ " \"base_model\": \"meta-llama/Meta-Llama-3-8B-Instruct\", # The base model\n",
+ " \"finetuned_model\": \"ruslanmv/Medical-Mind-Llama-3-8b-1M\", # The finetuned model\n",
+ " \"finetuned_name\": \"Medical-Mind-Llama-3-8b-v1M\",\n",
+ " \"max_seq_length\": 2048, # The maximum sequence length\n",
+ " \"dtype\": None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+\n",
+ " \"load_in_4bit\": True, # Load the model in 4-bit\n",
+ "}}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "* `base_model`: specifies the pre-trained model to use as the base model for fine-tuning.\n",
+ "* `finetuned_model`: specifies the finetuned model to use for fine-tuning.\n",
+ "* `finetuned_name`: specifies the name of the finetuned model.\n",
+ "* `max_seq_length`: specifies the maximum sequence length that the model can process.\n",
+ "* `dtype`: specifies the data type to use for the model's weights and activations. `None` means auto-detection, which will choose the most suitable data type based on the hardware.\n",
+ "* `load_in_4bit`: specifies whether to load the model i 4-bit precision, which can reduce memory usage and improve performance.\n"
+ ]
+ },
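+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a quick illustration of the `dtype` auto-detection described above, the sketch below mirrors the rule stated in the config comment (bfloat16 on GPUs that support it, float16 otherwise, e.g. Tesla T4/V100). It is only a sketch of the selection logic, not Unsloth's internal code."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch of the dtype auto-detection rule described above (not Unsloth's internal code):\n",
+ "# bfloat16 on GPUs that support it (Ampere+), float16 otherwise (e.g. Tesla T4/V100).\n",
+ "import torch\n",
+ "auto_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16\n",
+ "print(f\"dtype=None would resolve to: {auto_dtype}\")"
+ ]
+ },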
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**LoRA Configuration**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "lora_config={\"lora_config\": {\n",
+ " \"r\": 16, # The number of LoRA layers 8, 16, 32, 64\n",
+ " \"target_modules\": [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
+ " \"gate_proj\", \"up_proj\", \"down_proj\"], # The target modules\n",
+ " \"lora_alpha\": 16, # The alpha value for LoRA\n",
+ " \"lora_dropout\": 0, # The dropout value for LoRA\n",
+ " \"bias\": \"none\", # The bias for LoRA\n",
+ " \"use_gradient_checkpointing\": True, # Use gradient checkpointing\n",
+ " \"use_rslora\": False, # Use RSLora\n",
+ " \"use_dora\": False, # Use DoRa\n",
+ " \"loftq_config\": None # The LoFTQ configuration\n",
+ "}\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "* `r`: specifies the number of LoRA layers to use.\n",
+ "* `target_modules`: specifies the modules to which LoRA should be applied.\n",
+ "* `lora_alpha`: specifies the alpha value for LoRA, which controls the strength of the LoRA layers.\n",
+ "* `lora_dropout`: specifies the dropout value for LoRA, which controls the random dropping of neurons during training.\n",
+ "* `bias`: specifies the bias for LoRA, which can be set to \"none\" or a specific value.\n",
+ "* `use_gradient_checkpointing`: specifies whether to use gradient checkpointing, which can reduce memory usage during training.\n",
+ "* `use_rslora` and `use_dora`: specify whether to use RSLora and DoRa, respectively, which are variants of LoRA.\n",
+ "* `loftq_config`: specifies the LoFTQ configuration, which is not used in this example.\n"
+ ]
+ },
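+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To make the effect of `r` concrete, the sketch below estimates how many trainable parameters LoRA adds for this configuration. The layer shapes are taken from the Llama-3-8B module printout shown later in this notebook (32 decoder layers; 4096-dim attention projections, 1024-dim key/value projections, 14336-dim MLP hidden size). Each adapted linear layer of shape `d_in x d_out` contributes roughly `r * (d_in + d_out)` trainable parameters."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Rough estimate of the trainable LoRA parameters for this configuration.\n",
+ "# The (in_features, out_features) shapes below are assumptions based on the\n",
+ "# Llama-3-8B module printout shown further down in this notebook.\n",
+ "lora_r = 16\n",
+ "num_layers = 32\n",
+ "shapes = {\n",
+ " \"q_proj\": (4096, 4096),\n",
+ " \"k_proj\": (4096, 1024),\n",
+ " \"v_proj\": (4096, 1024),\n",
+ " \"o_proj\": (4096, 4096),\n",
+ " \"gate_proj\": (4096, 14336),\n",
+ " \"up_proj\": (4096, 14336),\n",
+ " \"down_proj\": (14336, 4096),\n",
+ "}\n",
+ "per_layer = sum(lora_r * (d_in + d_out) for d_in, d_out in shapes.values())\n",
+ "print(f\"Estimated trainable LoRA parameters: {per_layer * num_layers:,}\") # ~41,943,040"
+ ]
+ },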
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Training Configuration**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "training_config={\"training_config\": {\n",
+ " \"per_device_train_batch_size\": 2, # The batch size\n",
+ " \"gradient_accumulation_steps\": 4, # The gradient accumulation steps\n",
+ " \"warmup_steps\": 5, # The warmup steps\n",
+ " \"max_steps\": 0, # The maximum steps (0 if the epochs are defined)\n",
+ " \"num_train_epochs\": 1, # The number of training epochs\n",
+ " \"learning_rate\": 2e-4, # The learning rate\n",
+ " \"fp16\": not torch.cuda.is_bf16_supported(), # The fp16\n",
+ " \"bf16\": torch.cuda.is_bf16_supported(), # The bf16\n",
+ " \"logging_steps\": 1, # The logging steps\n",
+ " \"optim\": \"adamw_8bit\", # The optimizer\n",
+ " \"weight_decay\": 0.0, # The weight decay\n",
+ " \"lr_scheduler_type\": \"linear\", # The learning rate scheduler\n",
+ " \"seed\": 42, # The seed\n",
+ " \"output_dir\": \"outputs\", # The output directory\n",
+ "}\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "* `per_device_train_batch_size`: specifies the batch size to use for training.\n",
+ "* `gradient_accumulation_steps`: specifies the number of steps to accumulate gradients before updating the model.\n",
+ "* `warmup_steps`: specifies the number of warmup steps to perform before starting training.\n",
+ "* `max_steps`: specifies the maximum number of steps to train for. If set to 0, the model will train for the specified number of epochs.\n",
+ "* `num_train_epochs`: specifies the number of epochs to train for.\n",
+ "* `learning_rate`: specifies the initial learning rate to use for training.\n",
+ "* `fp16` and `bf16`: specify whether to use 16-bit floating-point precision (fp16) or 16-bit bfloat precision (bf16) for training.\n",
+ "* `logging_steps`: specifies the number of steps to log training metrics.\n",
+ "* `optim`: specifies the optimizer to use for training.\n",
+ "* `weight_decay`: specifies the weight decay rate to use for regularization.\n",
+ "* `lr_scheduler_type`: specifies the learning rate scheduler to use.\n",
+ "* `seed`: specifies the random seed to use for training.\n",
+ "* `output_dir`: specifies the output directory to save training results."
+ ]
+ },
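+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a small worked example of how these values interact, the sketch below computes the effective batch size and the number of optimizer steps for one epoch. It uses the values from `training_config` above and assumes a single GPU and the 100-example test subset prepared later in this notebook; the numbers reported by the trainer during an actual run may differ."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Back-of-the-envelope check of the training configuration above.\n",
+ "# Assumes a single GPU and the 100-example test subset used later in this notebook.\n",
+ "import math\n",
+ "\n",
+ "per_device_bs = 2 # per_device_train_batch_size\n",
+ "grad_accum = 4 # gradient_accumulation_steps\n",
+ "num_gpus = 1 # assumption for this sketch\n",
+ "num_examples = 100 # size of the test subset\n",
+ "\n",
+ "effective_batch = per_device_bs * grad_accum * num_gpus\n",
+ "steps_per_epoch = math.ceil(num_examples / effective_batch)\n",
+ "print(f\"Effective batch size: {effective_batch}\")\n",
+ "print(f\"Optimizer steps for one epoch over {num_examples} examples: {steps_per_epoch}\")"
+ ]
+ },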
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Hugging Face Username**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "hugging_face_username={\"hugging_face_username\": \"ruslanmv\"}\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Training Dataset**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "training_dataset={\"training_dataset\": {\n",
+ " \"name\": \"ruslanmv/ai-medical-dataset\", # The dataset name(huggingface/datasets)\n",
+ " \"split\": \"train\", # The dataset split\n",
+ " \"input_fields\": [\"question\", \"context\"] ,# The input fields\n",
+ " \"input_field\": \"text\",# The input field\n",
+ " },\n",
+ " }"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**`training_dataset`**: This is the top-level key for the dataset configuration.\n",
+ "\n",
+ "**`name`**: This specifies the name of the dataset. In this case, it's `ruslanmv/ai-medical-dataset`, which is a dataset hosted on the Hugging Face Hub. The format is `username/dataset_name`.\n",
+ "\n",
+ "**`split`**: This specifies the split of the dataset to use for training. In this case, it's set to `\"train\"`, which means the model will be trained on the training split of the dataset.\n",
+ "\n",
+ "**`input_fields`**: This specifies the input fields of the dataset that will be used for trainine, it's a list containing two fields: `\"question\"` and `\"context\"`. These fields are likely to be the input features of the dataset.\n",
+ "\n",
+ "**`input_field`**: This specifies the primary input field of the dataset. In this case, it's set to `\"text\"`. This field is likely to be the text input that the model will process.\n",
+ "\n",
+ "Here's an example of what this dataset might look like:\n",
+ "\n",
+ "| question | context | text |\n",
+ "| --- | --- | --- |\n",
+ "| How does COVID-19 spread? | COVID-19 is a respiratory disease... | The COVID-19 is.. |\n",
+ "| ... | ... | ... |"
+ ]
+ },
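+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a sanity check of the field mapping described above, the sketch below builds a `text` string for one hypothetical row by hand. It uses a simplified instruction/response template; the actual template used for training (`medical_prompt`, including the EOS token) is defined in the dataset-preparation cell further down."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Hypothetical example row, combined into a single training string in the same way\n",
+ "# the dataset-preparation cell below does (simplified template, no EOS token here).\n",
+ "example_row = {\n",
+ " \"question\": \"How does COVID-19 spread?\",\n",
+ " \"context\": \"COVID-19 is a respiratory disease...\",\n",
+ "}\n",
+ "template = \"### Instruction:\\n{}\\n\\n### Response:\\n{}\"\n",
+ "print(template.format(example_row[\"question\"], example_row[\"context\"]))"
+ ]
+ },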
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config = {}\n",
+ "config.update(hugging_face_username)\n",
+ "config.update(model_config)\n",
+ "config.update(lora_config)\n",
+ "config.update(training_config)\n",
+ "config.update(training_dataset)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "def save_config_to_json(config, file_path):\n",
+ " with open(file_path, 'w') as f:\n",
+ " json.dump(config, f, indent=4)\n",
+ "file_path = \"original_config.json\"\n",
+ "save_config_to_json(config, file_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "==((====))== Unsloth: Fast Llama patching release 2024.4\n",
+ " \\\\ /| GPU: Tesla V100-PCIE-16GB. Max memory: 15.773 GB. Platform = Linux.\n",
+ "O^O/ \\_/ \\ Pytorch: 2.2.1+cu121. CUDA = 7.0. CUDA Toolkit = 12.1.\n",
+ "\\ / Bfloat16 = FALSE. Xformers = 0.0.25. FA = False.\n",
+ " \"-____-\" Free Apache license: http://github.com/unslothai/unsloth\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in .\n",
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Loading the model and the tokenizer for the model\n",
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ " model_name=config[\"model_config\"].get(\"base_model\"),\n",
+ " max_seq_length=config[\"model_config\"].get(\"max_seq_length\"),\n",
+ " dtype=config[\"model_config\"].get(\"dtype\"),\n",
+ " load_in_4bit=config[\"model_config\"].get(\"load_in_4bit\"),\n",
+ "\n",
+ "\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Setup for QLoRA/LoRA peft of the base model\n",
+ "model = FastLanguageModel.get_peft_model(\n",
+ " model,\n",
+ " r = config.get(\"lora_config\").get(\"r\"),\n",
+ " target_modules = config.get(\"lora_config\").get(\"target_modules\"),\n",
+ " lora_alpha = config.get(\"lora_config\").get(\"lora_alpha\"),\n",
+ " lora_dropout = config.get(\"lora_config\").get(\"lora_dropout\"),\n",
+ " bias = config.get(\"lora_config\").get(\"bias\"),\n",
+ " use_gradient_checkpointing = config.get(\"lora_config\").get(\"use_gradient_checkpointing\"),\n",
+ " random_state = 42,\n",
+ " use_rslora = config.get(\"lora_config\").get(\"use_rslora\"),\n",
+ " use_dora = config.get(\"lora_config\").get(\"use_dora\"),\n",
+ " loftq_config = config.get(\"lora_config\").get(\"loftq_config\"),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer\n",
+ "#tokenizer = AutoTokenizer.from_pretrained(config.get(\"model_config\").get(\"base_model\"))\n",
+ "tokenizer.add_eos_token = True\n",
+ "tokenizer.pad_token_id = 0\n",
+ "tokenizer.padding_side = \"left\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Dataset does not exist.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2f16e19d432e4cfb8518dc80b1f54552",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading readme: 0%| | 0.00/2.97k [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "05063beb14274cbbaf203a403477efde",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Resolving data files: 0%| | 0/18 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2833c846be424c57a2240601530d430d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Resolving data files: 0%| | 0/18 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f45497aa51104f80b0501fd0c8a2efe4",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading data: 0%| | 0/18 [00:00, ?files/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2a9cd3f9409b4055a2e8f25249c6be07",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Generating train split: 0%| | 0/21210000 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d2deb39e743440abb7d7fbca11e7e864",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading dataset shards: 0%| | 0/18 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c46c30d45dbe4234a6390594356ed6cb",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Map: 0%| | 0/100 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0dfd1af07b0849a1ba9476cea77f6c5b",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Saving the dataset (0/1 shards): 0%| | 0/100 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "is_test=True\n",
+ "import datasets\n",
+ "import os\n",
+ "dataset_path = \"train_dataset\"\n",
+ "if os.path.exists(dataset_path):\n",
+ " print(\"Dataset exists!\")\n",
+ " train_dataset = datasets.load_from_disk(\"train_dataset\")\n",
+ "else:\n",
+ " print(\"Dataset does not exist.\")\n",
+ " # Loading the training dataset\n",
+ " train_dataset = load_dataset(config.get(\"training_dataset\").get(\"name\"), split = config.get(\"training_dataset\").get(\"split\")) \n",
+ " \n",
+ " if is_test:\n",
+ " # Select the first 1M rows of the dataset\n",
+ " train_dataset = train_dataset.select(range(100))\n",
+ " \n",
+ " medical_prompt = \"\"\"You are an AI Medical Assistant Chatbot, trained to answer medical questions. Below is an instruction that describes a task, paired with an response context. Write a response that appropriately completes the request.\n",
+ " ### Instruction:\n",
+ " {}\n",
+ "\n",
+ " ### Response:\n",
+ " {}\"\"\"\n",
+ " EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN\n",
+ " def formatting_prompts_func(examples):\n",
+ " instructions = examples[\"question\"]\n",
+ " outputs = examples[\"context\"]\n",
+ " texts = []\n",
+ " for instruction, output in zip(instructions, outputs):\n",
+ " # Must add EOS_TOKEN, otherwise your generation will go on forever!\n",
+ " text = medical_prompt.format(instruction, output) + EOS_TOKEN\n",
+ " texts.append(text)\n",
+ " return { \"text\" : texts, }\n",
+ " pass\n",
+ " train_dataset= train_dataset.map(formatting_prompts_func, batched = True,)\n",
+ " train_dataset['text'][1] \n",
+ " import datasets\n",
+ " # assume 'test_dataset' is a Dataset object\n",
+ " train_dataset.save_to_disk(\"train_dataset\") "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Dataset({\n",
+ " features: ['question', 'context', 'text'],\n",
+ " num_rows: 100\n",
+ "})"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train_dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "is_multiple=True"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Multiple GPUs enabled\n"
+ ]
+ }
+ ],
+ "source": [
+ "if is_multiple:\n",
+ " # Set up GPU acceleration\n",
+ " if torch.cuda.device_count() > 1:\n",
+ " print(\"Multiple GPUs enabled\")\n",
+ " devices = [f\"cuda:{i}\" for i in range(torch.cuda.device_count())]\n",
+ " model_parallel = torch.nn.DataParallel(model, device_ids= devices ) #[0, 1]\n",
+ " # Access the original model from the DataParallel object\n",
+ " model = model_parallel.module\n",
+ " else:\n",
+ " print(\"No DataParallel \")\n",
+ " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ " model.to(device) "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['cuda:0', 'cuda:1']"
+ ]
+ },
+ "execution_count": 32,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "devices"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "PeftModelForCausalLM(\n",
+ " (base_model): LoraModel(\n",
+ " (model): LlamaForCausalLM(\n",
+ " (model): LlamaModel(\n",
+ " (embed_tokens): Embedding(128256, 4096)\n",
+ " (layers): ModuleList(\n",
+ " (0-31): 32 x LlamaDecoderLayer(\n",
+ " (self_attn): LlamaSdpaAttention(\n",
+ " (q_proj): lora.Linear4bit(\n",
+ " (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Identity()\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=4096, out_features=16, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=16, out_features=4096, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (k_proj): lora.Linear4bit(\n",
+ " (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=False)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Identity()\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=4096, out_features=16, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=16, out_features=1024, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (v_proj): lora.Linear4bit(\n",
+ " (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=False)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Identity()\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=4096, out_features=16, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=16, out_features=1024, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (o_proj): lora.Linear4bit(\n",
+ " (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Identity()\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=4096, out_features=16, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=16, out_features=4096, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (rotary_emb): LlamaRotaryEmbedding()\n",
+ " )\n",
+ " (mlp): LlamaMLP(\n",
+ " (gate_proj): lora.Linear4bit(\n",
+ " (base_layer): Linear4bit(in_features=4096, out_features=14336, bias=False)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Identity()\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=4096, out_features=16, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=16, out_features=14336, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (up_proj): lora.Linear4bit(\n",
+ " (base_layer): Linear4bit(in_features=4096, out_features=14336, bias=False)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Identity()\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=4096, out_features=16, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=16, out_features=14336, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (down_proj): lora.Linear4bit(\n",
+ " (base_layer): Linear4bit(in_features=14336, out_features=4096, bias=False)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Identity()\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=14336, out_features=16, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=16, out_features=4096, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (act_fn): SiLU()\n",
+ " )\n",
+ " (input_layernorm): LlamaRMSNorm()\n",
+ " (post_attention_layernorm): LlamaRMSNorm()\n",
+ " )\n",
+ " )\n",
+ " (norm): LlamaRMSNorm()\n",
+ " )\n",
+ " (lm_head): Linear(in_features=4096, out_features=128256, bias=False)\n",
+ " )\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 31,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "abb0b4aadff048d9b612a7a8e6ebd4a9",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Map (num_proc=2): 0%| | 0/100 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:318: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.\n",
+ " warnings.warn(\n",
+ "Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Setting up the trainer for the model\n",
+ "trainer = SFTTrainer(\n",
+ " model = model,\n",
+ " tokenizer = tokenizer,\n",
+ " train_dataset = train_dataset,\n",
+ " dataset_text_field = config.get(\"training_dataset\").get(\"input_field\"),\n",
+ " max_seq_length = config.get(\"model_config\").get(\"max_seq_length\"),\n",
+ " dataset_num_proc = 2,\n",
+ " packing = False,\n",
+ " args = TrainingArguments(\n",
+ " per_device_train_batch_size = config.get(\"training_config\").get(\"per_device_train_batch_size\"),\n",
+ " gradient_accumulation_steps = config.get(\"training_config\").get(\"gradient_accumulation_steps\"),\n",
+ " warmup_steps = config.get(\"training_config\").get(\"warmup_steps\"),\n",
+ " max_steps = config.get(\"training_config\").get(\"max_steps\"),\n",
+ " num_train_epochs= config.get(\"training_config\").get(\"num_train_epochs\"),\n",
+ " learning_rate = config.get(\"training_config\").get(\"learning_rate\"),\n",
+ " fp16 = config.get(\"training_config\").get(\"fp16\"),\n",
+ " bf16 = config.get(\"training_config\").get(\"bf16\"),\n",
+ " logging_steps = config.get(\"training_config\").get(\"logging_steps\"),\n",
+ " optim = config.get(\"training_config\").get(\"optim\"),\n",
+ " weight_decay = config.get(\"training_config\").get(\"weight_decay\"),\n",
+ " lr_scheduler_type = config.get(\"training_config\").get(\"lr_scheduler_type\"),\n",
+ " seed = 42,\n",
+ " output_dir = config.get(\"training_config\").get(\"output_dir\"),\n",
+ " \n",
+ " ),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Reserved Memory: 11.19GB\n",
+ "Max Memory: 15.77GB\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Memory statistics before training\n",
+ "gpu_statistics = torch.cuda.get_device_properties(0)\n",
+ "reserved_memory = round(torch.cuda.max_memory_reserved() / 1024**3, 2)\n",
+ "max_memory = round(gpu_statistics.total_memory / 1024**3, 2)\n",
+ "print(f\"Reserved Memory: {reserved_memory}GB\")\n",
+ "print(f\"Max Memory: {max_memory}GB\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {
+ "id": "Q-g-4RvNXkyD"
+ },
+ "outputs": [],
+ "source": [
+ "## [ 1038/2651250 53:49 < 2295:10:28, 0.32 it/s, Epoch 0.00/1] old"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 948
+ },
+ "id": "yI9mEQ7ZOUx2",
+ "outputId": "6466d591-76f8-45e2-e665-39ad9bf8ae7f",
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "==((====))== Unsloth - 2x faster free finetuning | Num GPUs = 1\n",
+ " \\\\ /| Num examples = 100 | Num Epochs = 1\n",
+ "O^O/ \\_/ \\ Batch size per device = 4 | Gradient Accumulation steps = 4\n",
+ "\\ / Total batch size = 16 | Total steps = 6\n",
+ " \"-____-\" Number of trainable parameters = 41,943,040\n"
+ ]
+ },
+ {
+ "ename": "RuntimeError",
+ "evalue": "Caught RuntimeError in replica 1 on device 1.\nOriginal Traceback (most recent call last):\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py\", line 83, in _worker\n output = module(*input, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1511, in _wrapped_call_impl\n return self._call_impl(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1520, in _call_impl\n return forward_call(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/unsloth/models/llama.py\", line 882, in PeftModelForCausalLM_fast_forward\n return self.base_model(\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1511, in _wrapped_call_impl\n return self._call_impl(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1520, in _call_impl\n return forward_call(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/peft/tuners/tuners_utils.py\", line 161, in forward\n return self.model.forward(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/accelerate/hooks.py\", line 166, in new_forward\n output = module._old_forward(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/unsloth/models/llama.py\", line 813, in _CausalLM_fast_forward\n outputs = self.model(\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1511, in _wrapped_call_impl\n return self._call_impl(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1520, in _call_impl\n return forward_call(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/accelerate/hooks.py\", line 166, in new_forward\n output = module._old_forward(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/unsloth/models/llama.py\", line 668, in LlamaModel_fast_forward\n layer_outputs = torch.utils.checkpoint.checkpoint(\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/_compile.py\", line 24, in inner\n return torch._dynamo.disable(fn, recursive)(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py\", line 489, in _fn\n return fn(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/_dynamo/external_utils.py\", line 17, in inner\n return fn(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/utils/checkpoint.py\", line 482, in checkpoint\n return CheckpointFunction.apply(function, preserve, *args)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/autograd/function.py\", line 553, in apply\n return super().apply(*args, **kwargs) # type: ignore[misc]\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/utils/checkpoint.py\", line 261, in forward\n outputs = run_function(*args)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/unsloth/models/llama.py\", line 664, in custom_forward\n return module(*inputs, past_key_value, output_attentions, padding_mask = 
padding_mask)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1511, in _wrapped_call_impl\n return self._call_impl(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1520, in _call_impl\n return forward_call(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/accelerate/hooks.py\", line 166, in new_forward\n output = module._old_forward(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/unsloth/models/llama.py\", line 433, in LlamaDecoderLayer_fast_forward\n hidden_states, self_attn_weights, present_key_value = self.self_attn(\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1511, in _wrapped_call_impl\n return self._call_impl(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1520, in _call_impl\n return forward_call(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/accelerate/hooks.py\", line 166, in new_forward\n output = module._old_forward(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/unsloth/models/llama.py\", line 308, in LlamaAttention_fast_forward\n Q, K, V = self.apply_qkv(self, hidden_states)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/unsloth/kernels/fast_lora.py\", line 312, in apply_lora_qkv\n Q, K, V = LoRA_QKV.apply(X,\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/autograd/function.py\", line 553, in apply\n return super().apply(*args, **kwargs) # type: ignore[misc]\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py\", line 115, in decorate_fwd\n return fwd(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/unsloth/kernels/fast_lora.py\", line 227, in forward\n Q = matmul_lora(X, QW, QW_quant, QA, QB, QS)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/unsloth/kernels/utils.py\", line 225, in matmul_lora\n W = fast_dequantize(W.t(), W_quant)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/unsloth/kernels/utils.py\", line 104, in fast_dequantize\n out_absmax += offset\nRuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!\n",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[28], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Training the model\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m trainer_stats \u001b[38;5;241m=\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:361\u001b[0m, in \u001b[0;36mSFTTrainer.train\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 358\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mneftune_noise_alpha \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_trainer_supports_neftune:\n\u001b[1;32m 359\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_trl_activate_neftune(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel)\n\u001b[0;32m--> 361\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 363\u001b[0m \u001b[38;5;66;03m# After training we make sure to retrieve back the original forward pass method\u001b[39;00m\n\u001b[1;32m 364\u001b[0m \u001b[38;5;66;03m# for the embedding layer by removing the forward post hook.\u001b[39;00m\n\u001b[1;32m 365\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mneftune_noise_alpha \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_trainer_supports_neftune:\n",
+ "File \u001b[0;32m/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/transformers/trainer.py:1859\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1857\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m 1858\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1859\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1860\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1861\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1862\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1863\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1864\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m:361\u001b[0m, in \u001b[0;36m_fast_inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n",
+ "File \u001b[0;32m/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/transformers/trainer.py:3138\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m 3135\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss_mb\u001b[38;5;241m.\u001b[39mreduce_mean()\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m 3137\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompute_loss_context_manager():\n\u001b[0;32m-> 3138\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3140\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mn_gpu \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 3141\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mmean() \u001b[38;5;66;03m# mean() to average on multi-gpu parallel training\u001b[39;00m\n",
+ "File \u001b[0;32m/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/transformers/trainer.py:3161\u001b[0m, in \u001b[0;36mTrainer.compute_loss\u001b[0;34m(self, model, inputs, return_outputs)\u001b[0m\n\u001b[1;32m 3159\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3160\u001b[0m labels \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m-> 3161\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3162\u001b[0m \u001b[38;5;66;03m# Save past state if it exists\u001b[39;00m\n\u001b[1;32m 3163\u001b[0m \u001b[38;5;66;03m# TODO: this needs to be fixed and made cleaner later.\u001b[39;00m\n\u001b[1;32m 3164\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mpast_index \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m:\n",
+ "File \u001b[0;32m/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
+ "File \u001b[0;32m/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:185\u001b[0m, in \u001b[0;36mDataParallel.forward\u001b[0;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[1;32m 183\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule(\u001b[38;5;241m*\u001b[39minputs[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodule_kwargs[\u001b[38;5;241m0\u001b[39m])\n\u001b[1;32m 184\u001b[0m replicas \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreplicate(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice_ids[:\u001b[38;5;28mlen\u001b[39m(inputs)])\n\u001b[0;32m--> 185\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparallel_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreplicas\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodule_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgather(outputs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_device)\n",
+ "File \u001b[0;32m/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:200\u001b[0m, in \u001b[0;36mDataParallel.parallel_apply\u001b[0;34m(self, replicas, inputs, kwargs)\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mparallel_apply\u001b[39m(\u001b[38;5;28mself\u001b[39m, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m List[Any]:\n\u001b[0;32m--> 200\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mparallel_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreplicas\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice_ids\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mreplicas\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py:108\u001b[0m, in \u001b[0;36mparallel_apply\u001b[0;34m(modules, inputs, kwargs_tup, devices)\u001b[0m\n\u001b[1;32m 106\u001b[0m output \u001b[38;5;241m=\u001b[39m results[i]\n\u001b[1;32m 107\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(output, ExceptionWrapper):\n\u001b[0;32m--> 108\u001b[0m \u001b[43moutput\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreraise\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 109\u001b[0m outputs\u001b[38;5;241m.\u001b[39mappend(output)\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs\n",
+ "File \u001b[0;32m/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/_utils.py:722\u001b[0m, in \u001b[0;36mExceptionWrapper.reraise\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 718\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m 719\u001b[0m \u001b[38;5;66;03m# If the exception takes multiple arguments, don't try to\u001b[39;00m\n\u001b[1;32m 720\u001b[0m \u001b[38;5;66;03m# instantiate since we don't know how to\u001b[39;00m\n\u001b[1;32m 721\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(msg) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 722\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m exception\n",
+ "\u001b[0;31mRuntimeError\u001b[0m: Caught RuntimeError in replica 1 on device 1.\nOriginal Traceback (most recent call last):\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py\", line 83, in _worker\n output = module(*input, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1511, in _wrapped_call_impl\n return self._call_impl(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1520, in _call_impl\n return forward_call(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/unsloth/models/llama.py\", line 882, in PeftModelForCausalLM_fast_forward\n return self.base_model(\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1511, in _wrapped_call_impl\n return self._call_impl(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1520, in _call_impl\n return forward_call(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/peft/tuners/tuners_utils.py\", line 161, in forward\n return self.model.forward(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/accelerate/hooks.py\", line 166, in new_forward\n output = module._old_forward(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/unsloth/models/llama.py\", line 813, in _CausalLM_fast_forward\n outputs = self.model(\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1511, in _wrapped_call_impl\n return self._call_impl(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1520, in _call_impl\n return forward_call(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/accelerate/hooks.py\", line 166, in new_forward\n output = module._old_forward(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/unsloth/models/llama.py\", line 668, in LlamaModel_fast_forward\n layer_outputs = torch.utils.checkpoint.checkpoint(\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/_compile.py\", line 24, in inner\n return torch._dynamo.disable(fn, recursive)(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py\", line 489, in _fn\n return fn(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/_dynamo/external_utils.py\", line 17, in inner\n return fn(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/utils/checkpoint.py\", line 482, in checkpoint\n return CheckpointFunction.apply(function, preserve, *args)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/autograd/function.py\", line 553, in apply\n return super().apply(*args, **kwargs) # type: ignore[misc]\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/utils/checkpoint.py\", line 261, in forward\n outputs = run_function(*args)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/unsloth/models/llama.py\", line 664, in custom_forward\n return module(*inputs, past_key_value, 
output_attentions, padding_mask = padding_mask)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1511, in _wrapped_call_impl\n return self._call_impl(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1520, in _call_impl\n return forward_call(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/accelerate/hooks.py\", line 166, in new_forward\n output = module._old_forward(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/unsloth/models/llama.py\", line 433, in LlamaDecoderLayer_fast_forward\n hidden_states, self_attn_weights, present_key_value = self.self_attn(\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1511, in _wrapped_call_impl\n return self._call_impl(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1520, in _call_impl\n return forward_call(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/accelerate/hooks.py\", line 166, in new_forward\n output = module._old_forward(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/unsloth/models/llama.py\", line 308, in LlamaAttention_fast_forward\n Q, K, V = self.apply_qkv(self, hidden_states)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/unsloth/kernels/fast_lora.py\", line 312, in apply_lora_qkv\n Q, K, V = LoRA_QKV.apply(X,\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/autograd/function.py\", line 553, in apply\n return super().apply(*args, **kwargs) # type: ignore[misc]\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py\", line 115, in decorate_fwd\n return fwd(*args, **kwargs)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/unsloth/kernels/fast_lora.py\", line 227, in forward\n Q = matmul_lora(X, QW, QW_quant, QA, QB, QS)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/unsloth/kernels/utils.py\", line 225, in matmul_lora\n W = fast_dequantize(W.t(), W_quant)\n File \"/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/unsloth/kernels/utils.py\", line 104, in fast_dequantize\n out_absmax += offset\nRuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Training the model\n",
+ "trainer_stats = trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "YQFEr64koXp1",
+ "outputId": "2e1b3775-1f5d-4b8e-a0ef-32266cb7fa2a"
+ },
+ "outputs": [],
+ "source": [
+ "# Memory statistics after training\n",
+ "used_memory = round(torch.cuda.max_memory_allocated() / 1024**3, 2)\n",
+ "used_memory_lora = round(used_memory - reserved_memory, 2)\n",
+ "used_memory_persentage = round((used_memory / max_memory) * 100, 2)\n",
+ "used_memory_lora_persentage = round((used_memory_lora / max_memory) * 100, 2)\n",
+ "print(f\"Used Memory: {used_memory}GB ({used_memory_persentage}%)\")\n",
+ "print(f\"Used Memory for training(fine-tuning) LoRA: {used_memory_lora}GB ({used_memory_lora_persentage}%)\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "new_model=config.get(\"model_config\").get(\"finetuned_model\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "new_model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "1YJB4bZyoXp1"
+ },
+ "outputs": [],
+ "source": [
+ "# Saving the trainer stats\n",
+ "with open(\"trainer_stats.json\", \"w\") as f:\n",
+ " json.dump(trainer_stats, f, indent=4)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "## Save and push the adapter to HF\n",
+ "import os\n",
+ "current_directory = os.getcwd()\n",
+ "# New model name\n",
+ "new_model = config.get(\"model_config\").get(\"finetuned_name\") #\"Medical-Mind-Llama-3-8b\"\n",
+ "# Save the fine-tuned model\n",
+ "save_path = os.path.join(current_directory, \"models\", new_model)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#os.makedirs(save_path, exist_ok=True) # Create directory if it doesn't exist\n",
+ "#trainer.model.save_pretrained(save_path)\n",
+ "tokenizer.save_pretrained(save_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "help(model.save_pretrained_merged)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To save the final model as LoRA adapters, either use Huggingface's push_to_hub for an online save or save_pretrained for a local save.\n",
+ "\n",
+ "[NOTE] This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!"
+ ]
+ },
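+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A minimal sketch of the two options mentioned above (local `save_pretrained` vs. online `push_to_hub`). The repository name is only an example; `model`, `tokenizer`, `save_path` and `token` come from earlier cells. It is left commented out, like the merged-weight variants in the cells that follow."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Minimal sketch of the two adapter-saving options mentioned above, kept commented out.\n",
+ "# The repository name is only an example; model, tokenizer, save_path and token come from earlier cells.\n",
+ "# Local save of the LoRA adapters only:\n",
+ "#model.save_pretrained(save_path)\n",
+ "#tokenizer.save_pretrained(save_path)\n",
+ "# Online save of the adapters to the Hugging Face Hub:\n",
+ "#model.push_to_hub(\"ruslanmv/Medical-Mind-Llama-3-8b-adapters\", token=token)\n",
+ "#tokenizer.push_to_hub(\"ruslanmv/Medical-Mind-Llama-3-8b-adapters\", token=token)"
+ ]
+ },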
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Save the model to the created directory\n",
+ "# `lora`: Save LoRA adapters with no merging. Useful for HF inference.\n",
+ "#model.save_pretrained(save_path)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Saving the model using merged_16bit(float16), \n",
+ "#`16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.\n",
+ "#model.save_pretrained_merged(save_path, tokenizer, save_method = \"merged_16bit\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# `4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.\n",
+ "model.save_pretrained_merged(save_path, tokenizer, save_method = \"merged_4bit_forced\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "save_path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Get the list of files in the directory\n",
+ "files_in_model_dir = os.listdir(save_path)\n",
+ "# Print the list of files\n",
+ "print(\"Files in the directory:\")\n",
+ "for file in files_in_model_dir:\n",
+ " print(file)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from huggingface_hub import HfApi\n",
+ "def upload_folder(folder_path, repository_name, path_in_repo):\n",
+ " api = HfApi()\n",
+ " \n",
+ " # Check if the repository exists, if not, create it\n",
+ " repo_exists = api.repo_exists(repository_name)\n",
+ " if not repo_exists:\n",
+ " api.create_repo(repository_name)\n",
+ " print(f\"Repository '{repository_name}' created on Hugging Face Hub.\")\n",
+ "\n",
+ " for root, dirs, files in os.walk(folder_path):\n",
+ " for file in files:\n",
+ " file_path = os.path.join(root, file)\n",
+ " relative_path = os.path.relpath(file_path, folder_path)\n",
+ " repo_path = os.path.join(path_in_repo, relative_path)\n",
+ " api.upload_file(path_or_fileobj=file_path, repo_id=repository_name, path_in_repo=repo_path)\n",
+ " print(f\"{repo_path} uploaded to {repository_name}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define the repository name and path in the repository\n",
+ "repository_name = \"ruslanmv/\"+new_model\n",
+ "path_in_repo = \"\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "repository_name"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Upload the folder and its contents to the repository\n",
+ "upload_folder(save_path, repository_name, path_in_repo)"
+ ]
+ },
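+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For reference, recent versions of `huggingface_hub` also expose `HfApi.upload_folder`, which uploads a whole directory in one call and does the same job as the file-by-file helper above. A commented-out sketch, using the same `save_path` and `repository_name` as the cell above:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Alternative to the upload_folder() helper defined above: one call to HfApi.upload_folder.\n",
+ "# Commented out so the upload is not run twice; assumes save_path and repository_name from earlier cells.\n",
+ "#from huggingface_hub import HfApi\n",
+ "#HfApi().upload_folder(folder_path=save_path, repo_id=repository_name)"
+ ]
+ },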
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#help(model.push_to_hub_merged)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#save_path='/home/wsuser/work/models/Medical-Mind-Llama-3-8b'\n",
+ "#repo_id='ruslanmv/Medical-Mind-Llama-3-8b'\n",
+ "#commit_message=\"Uploading Model\"\n",
+ "#model.push_to_hub_merged(repo_id, tokenizer, save_method = \"merged_16bit\",commit_message=commit_message)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#model.push_to_hub_merged(config.get(\"model_config\").get(\"finetuned_model\"), tokenizer, save_method = \"merged_4bit\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#model.save_pretrained_gguf(config.get(\"model_config\").get(\"finetuned_model\"), tokenizer)\n",
+ "#model.push_to_hub_gguf(config.get(\"model_config\").get(\"finetuned_model\"), tokenizer,repository_private=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#model.save_pretrained_gguf(config.get(\"model_config\").get(\"finetuned_model\"), tokenizer, quantization_method = \"q4_k_m\")\n",
+ "#model.push_to_hub_gguf(config.get(\"model_config\").get(\"finetuned_model\"), tokenizer, quantization_method = \"q4_k_m\",private=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/conda/envs/Python-RT23.1-CUDA/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "8f1cc811458d464c906abab21b49d230",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "config.json: 0%| | 0.00/1.12k [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "==((====))== Unsloth: Fast Llama patching release 2024.4\n",
+ " \\\\ /| GPU: Tesla V100-PCIE-16GB. Max memory: 15.773 GB. Platform = Linux.\n",
+ "O^O/ \\_/ \\ Pytorch: 2.2.1+cu121. CUDA = 7.0. CUDA Toolkit = 12.1.\n",
+ "\\ / Bfloat16 = FALSE. Xformers = 0.0.25. FA = False.\n",
+ " \"-____-\" Free Apache license: http://github.com/unslothai/unsloth\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Unused kwargs: ['quant_method']. These kwargs are not used in .\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c900fe27e6ab4026bf82f8d9fc8127ce",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model.safetensors.index.json: 0%| | 0.00/132k [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "71c34d1aab5b4eafa1ac4bedea7a99e1",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading shards: 0%| | 0/2 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a86b815277ad420c98f60f36e03d1db6",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model-00001-of-00002.safetensors: 0%| | 0.00/4.65G [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "97cb140a6fc0412eb54980d172002cb8",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model-00002-of-00002.safetensors: 0%| | 0.00/1.05G [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "7dfd858cf250481795c44a4a255e8963",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "7712983e35464f43b533191252bc6c26",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "generation_config.json: 0%| | 0.00/143 [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "dc245edfb09f47aeb8d2ca1c0e4b3085",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "tokenizer_config.json: 0%| | 0.00/51.0k [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2089efc7021b4f98a904e6310b8552a6",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "tokenizer.json: 0%| | 0.00/9.09M [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "241ce29274d845819b32f30e334f2cad",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "special_tokens_map.json: 0%| | 0.00/321 [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
+ "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Answer: I'd be happy to help you with that!\n",
+ "\n",
+ "The main cause of inflammatory CD4+ T cells is a complex process involving multiple factors. CD4+ T cells, also known as T helper cells, play a crucial role in the immune response. Inflammation occurs when these cells become activated and produce pro-inflammatory cytokines, leading to an imbalance in the immune response.\n",
+ "\n",
+ "Some common causes of inflammatory CD4+ T cells include:\n",
+ "\n",
+ "1. **Infections**: Bacterial, viral, or fungal infections can trigger an immune response, leading to the activation of CD4+ T cells and subsequent inflammation.\n",
+ "2. **Autoimmune disorders**: Conditions like rheumatoid arthritis, lupus, or multiple sclerosis can cause CD4+ T cells to become chronically activated, leading to chronic inflammation.\n",
+ "3. **Allergies**: Allergic reactions can trigger the activation of CD4+ T cells, resulting in the release of pro-inflammatory cytokines and the development of inflammation.\n",
+ "4. **Cancer**: Tumor cells can stimulate the activation of CD4+ T cells, leading to an inflammatory response.\n",
+ "5. **Environmental factors**: Exposure to pollutants, toxins, or other environmental stressors can contribute to the activation of CD4+ T cells and the development\n"
+ ]
+ }
+ ],
+ "source": [
+ "is_inference=False\n",
+ "if is_inference:\n",
+ " from unsloth import FastLanguageModel\n",
+ " import torch\n",
+ " max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!\n",
+ " dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+\n",
+ " load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.\n",
+ " model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ " model_name = \"ruslanmv/Medical-Mind-Llama-3-8b-1M\",\n",
+ " max_seq_length = max_seq_length,\n",
+ " dtype = dtype,\n",
+ " load_in_4bit = load_in_4bit,\n",
+ " )\n",
+ " # Using FastLanguageModel for fast inference\n",
+ " FastLanguageModel.for_inference(model)\n",
+ " question=\"This is the question: What was the main cause of the inflammatory CD4+ T cells?\"\n",
+ " prompt=f\"<|start_header_id|>system<|end_header_id|> You are a Medical AI chatbot assistant .<|eot_id|><|start_header_id|> user <|end_header_id|>{question}<|eot_id|>\"\n",
+ " # Tokenizing the input and generating the output\n",
+ " inputs = tokenizer([prompt], return_tensors = \"pt\").to(\"cuda\")\n",
+ " outputs = model.generate(**inputs, max_new_tokens = 256, use_cache = True)\n",
+ " answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] # Get the first element from the batch\n",
+ "\n",
+ " # Split the answer at the first line break, assuming system intro and question are on separate lines\n",
+ " answer_parts = answer.split(\"\\n\", 1)\n",
+ "\n",
+ " # If there are multiple parts, consider the second part as the answer\n",
+ " if len(answer_parts) > 1:\n",
+ " answer = answer_parts[1].strip() # Remove leading/trailing whitespaces\n",
+ " else:\n",
+ " answer = \"\" # If no split possible, set answer to empty string\n",
+ "\n",
+ " print(f\"Answer: {answer}\") \n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Hyperparameter search\n",
+ "**Step 1: Define the Hyperparameter Search Space**\n",
+ "We need to define the search space for the hyperparameters we want to tune. For example, let's say we want to tune the following hyperparameters:\n",
+ "\n",
+ "* `learning_rate`\n",
+ "* `per_device_train_batch_size`\n",
+ "* `gradient_accumulation_steps`\n",
+ "* `warmup_steps`\n",
+ "* `num_train_epochs`\n",
+ "* `lora_alpha`\n",
+ "* `lora_dropout`\n",
+ "\n",
+ "We can define the search space as follows:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import hyperopt\n",
+ "from hyperopt import hp\n",
+ "from hyperopt import Trials\n",
+ "from hyperopt import fmin, tpe, Trials\n",
+ "# Define the search space for hyperparameters\n",
+ "space = {\n",
+ " 'learning_rate': hp.loguniform('learning_rate', -5, -1), # Learning rate in log scale\n",
+ " #'lora_alpha': hp.quniform('lora_alpha', 1, 32, 1), # LoRA alpha with quantized steps\n",
+ " #'lora_dropout': hp.uniform('lora_dropout', 0, 0.5), # LoRA dropout rate\n",
+ "\n",
+ " 'per_device_train_batch_size': hp.quniform('per_device_train_batch_size', 2, 16, q=1), \n",
+ " 'gradient_accumulation_steps': hp.quniform('gradient_accumulation_steps', 1, 8, 1), # Added for exploration\n",
+ " # Uncomment these if you want to tune other hyperparameters\n",
+ " # 'warmup_steps': hp.quniform('warmup_steps', 0, 1000, 1),\n",
+ " # 'num_train_epochs': hp.quniform('num_train_epochs', 1, 5, 1), \n",
+ "\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Step 2. Define the Objective Function**\n",
+ "\n",
+ "The objective function is a function that takes in the hyperparameters, sets them in the `config` dictionary, trains the model, and returns the loss or metric to minimize. We need to modify the previous fine-tuning code to define the objective function."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def objective(params):\n",
+ " # Set hyperparameters in the config dictionary (assuming it's defined elsewhere)\n",
+ " config['training_config']['learning_rate'] = params['learning_rate']\n",
+ " # config['lora_config']['lora_alpha'] = params['lora_alpha']\n",
+ " # config['lora_config']['lora_dropout'] = params['lora_dropout'] \n",
+ " config['training_config']['per_device_train_batch_size'] = params['per_device_train_batch_size']\n",
+ " config['training_config']['gradient_accumulation_steps'] = params['gradient_accumulation_steps']\n",
+ " # ... Set other hyperparameters from params dictionary ... \n",
+ " #config['training_config']['warmup_steps'] = params['warmup_steps']\n",
+ " #config['training_config']['num_train_epochs'] = params['num_train_epochs']\n",
+ "\n",
+ " # Load the model and tokenizer (assuming these are defined elsewhere)\n",
+ " try:\n",
+ " model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ " model_name=config.get(\"model_config\").get(\"base_model\"),\n",
+ " max_seq_length=config.get(\"model_config\").get(\"max_seq_length\"),\n",
+ " dtype=config.get(\"model_config\").get(\"dtype\"),\n",
+ " load_in_4bit=config.get(\"model_config\").get(\"load_in_4bit\"),\n",
+ " )\n",
+ " except Exception as e:\n",
+ " print(f\"Error loading model and tokenizer: {e}\")\n",
+ " return float(\"inf\") # Return high value for errors\n",
+ "\n",
+ " # Setup LoRA for the model (assuming FastLanguageModel supports LoRA)\n",
+ " try:\n",
+ " model = FastLanguageModel.get_peft_model(\n",
+ " model,\n",
+ " r=config.get(\"lora_config\").get(\"r\"),\n",
+ " target_modules=config.get(\"lora_config\").get(\"target_modules\"),\n",
+ " lora_alpha=params['lora_alpha'],\n",
+ " lora_dropout=params['lora_dropout'],\n",
+ " bias=config.get(\"lora_config\").get(\"bias\"),\n",
+ " use_gradient_checkpointing=config.get(\"lora_config\").get(\"use_gradient_checkpointing\"),\n",
+ " random_state=42,\n",
+ " use_rslora=config.get(\"lora_config\").get(\"use_rslora\"),\n",
+ " use_dora=config.get(\"lora_config\").get(\"use_dora\"),\n",
+ " loftq_config=config.get(\"lora_config\").get(\"loftq_config\")\n",
+ " )\n",
+ " except Exception as e:\n",
+ " print(f\"Error setting up LoRA: {e}\")\n",
+ " return float(\"inf\") # Return high value for errors\n",
+ " # Train the model on the test dataset (assuming SFTTrainer and training arguments are defined)\n",
+ " try:\n",
+ " trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " tokenizer=tokenizer,\n",
+ " train_dataset=train_dataset,\n",
+ " dataset_text_field=config.get(\"training_dataset\").get(\"input_field\"),\n",
+ " max_seq_length=config.get(\"model_config\").get(\"max_seq_length\"),\n",
+ " dataset_num_proc=2,\n",
+ " packing=False,\n",
+ " args=TrainingArguments(\n",
+ " per_device_train_batch_size=int(params['per_device_train_batch_size']),\n",
+ " gradient_accumulation_steps=params['gradient_accumulation_steps'],\n",
+ " warmup_steps=params['warmup_steps'],\n",
+ " max_steps=config.get(\"training_config\").get(\"max_steps\"),\n",
+ " num_train_epochs=params['num_train_epochs'],\n",
+ " learning_rate=params['learning_rate'],\n",
+ " fp16=config.get(\"training_config\").get(\"fp16\"),\n",
+ " bf16=config.get(\"training_config\").get(\"bf16\"),\n",
+ " logging_steps=config.get(\"training_config\").get(\"logging_steps\"),\n",
+ " optim=config.get(\"training_config\").get(\"optim\"),\n",
+ " weight_decay=config.get(\"training_config\").get(\"weight_decay\"),\n",
+ " lr_scheduler_type=config.get(\"training_config\").get(\"lr_scheduler_type\"),\n",
+ " seed=42,\n",
+ " output_dir=config.get(\"training_config\").get(\"output_dir\")\n",
+ " )\n",
+ " )\n",
+ " trainer_stats = trainer.train()\n",
+ " return trainer_stats.loss # Assuming loss is the metric to minimize\n",
+ " except Exception as e:\n",
+ " print(f\"Error during training: {e}\")\n",
+ " return float(\"inf\") # Return high value for failed trials\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Step 3: Perform Hyperparameter Search**\n",
+ "\n",
+ "Now that we have defined the objective function, we can perform the hyperparameter search using Hyperopt's `fmin` function. We need to specify the objective function, the search space, and the maximum number of evaluations."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Create a Trials object to track hyperparameter evaluations\n",
+ "trials = Trials()\n",
+ "# Run hyperparameter optimization using TPE algorithm\n",
+ "best = fmin(objective, space, algo=tpe.suggest, trials=trials, max_evals=2)\n",
+ "# Print the best hyperparameters found during optimization\n",
+ "print(\"Best Hyperparameters:\", best)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import gc\n",
+ "def reset_gpu_memory():\n",
+ " torch.cuda.empty_cache()\n",
+ " gc.collect()\n",
+ " print(\"GPU memory cleared!\")\n",
+ "# Example usage:\n",
+ "reset_gpu_memory()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Full code version"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#### Fixed Version\n",
+ "import hyperopt\n",
+ "from hyperopt import hp\n",
+ "from hyperopt import Trials\n",
+ "from hyperopt import fmin, tpe, Trials\n",
+ "# Define the search space for hyperparameters\n",
+ "space = {\n",
+ " 'learning_rate' : hp.loguniform('learning_rate', -5, -1), # Learning rate in log scale\n",
+ " 'per_device_train_batch_size': hp.quniform('per_device_train_batch_size', 2, 16, 1), \n",
+ " 'gradient_accumulation_steps': hp.quniform('gradient_accumulation_steps', 1, 8, 1), \n",
+ " # Uncomment these if you want to tune them\n",
+ " #'lora_alpha' : hp.quniform('lora_alpha', 1, 32, 1), # LoRA alpha with quantized steps\n",
+ " #'lora_dropout' : hp.uniform('lora_dropout', 0, 0.5), # LoRA dropout rate\n",
+ " # 'warmup_steps' : hp.quniform('warmup_steps', 0, 1000, 1),\n",
+ " # 'num_train_epochs' : hp.quniform('num_train_epochs', 1, 5, 1),\n",
+ "}\n",
+ "def objective(params):\n",
+ " # Set hyperparameters in the config dictionary (assuming it's defined elsewhere)\n",
+ " config['training_config']['learning_rate']=params['learning_rate']\n",
+ " config['training_config']['per_device_train_batch_size'] = params['per_device_train_batch_size']\n",
+ " config['training_config']['gradient_accumulation_steps'] = params['gradient_accumulation_steps'] \n",
+ " # config['lora_config']['lora_alpha'] = params['lora_alpha']\n",
+ " # config['lora_config']['lora_dropout'] = params['lora_dropout']\n",
+ " # ... Set other hyperparameters from params dictionary ...\n",
+ " #config['training_config']['warmup_steps'] = params['warmup_steps']\n",
+ " #config['training_config']['num_train_epochs'] = params['num_train_epochs']\n",
+ " # Load the model and tokenizer (assuming these are defined elsewhere) \n",
+ " try:\n",
+ " model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ " model_name = config.get(\"model_config\").get(\"base_model\"),\n",
+ " max_seq_length = config.get(\"model_config\").get(\"max_seq_length\"),\n",
+ " dtype = config.get(\"model_config\").get(\"dtype\"),\n",
+ " load_in_4bit = config.get(\"model_config\").get(\"load_in_4bit\")\n",
+ " )\n",
+ " except Exception as e:\n",
+ " print(f\"Error loading model and tokenizer: {e}\")\n",
+ " return float(\"inf\") # Return high value for errors\n",
+ "\n",
+ " # Setup LoRA for the model (assuming FastLanguageModel supports LoRA)\n",
+ " try:\n",
+ " model = FastLanguageModel.get_peft_model(\n",
+ " model,\n",
+ " r = config.get(\"lora_config\").get(\"r\"),\n",
+ " target_modules = config.get(\"lora_config\").get(\"target_modules\"),\n",
+ " lora_alpha = config.get(\"lora_config\").get('lora_alpha'), #params['lora_alpha'],\n",
+ " lora_dropout = config.get(\"lora_config\").get('lora_dropout'),#params['lora_dropout'],\n",
+ " bias = config.get(\"lora_config\").get(\"bias\"),\n",
+ " use_gradient_checkpointing = config.get(\"lora_config\").get(\"use_gradient_checkpointing\"),\n",
+ " random_state = 42,\n",
+ " use_rslora = config.get(\"lora_config\").get(\"use_rslora\"),\n",
+ " use_dora = config.get(\"lora_config\").get(\"use_dora\"),\n",
+ " loftq_config = config.get(\"lora_config\").get(\"loftq_config\")\n",
+ " )\n",
+ " except Exception as e:\n",
+ " print(f\"Error setting up LoRA: {e}\")\n",
+ " return float(\"inf\") # Return high value for errors\n",
+ " # Train the model on the test dataset (assuming SFTTrainer and training arguments are defined)\n",
+ " try:\n",
+ " trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " tokenizer = tokenizer,\n",
+ " train_dataset = train_dataset,\n",
+ " dataset_text_field = config.get(\"training_dataset\").get(\"input_field\"),\n",
+ " max_seq_length = config.get(\"model_config\").get(\"max_seq_length\"),\n",
+ " dataset_num_proc = 2,\n",
+ " packing = False,\n",
+ " args=TrainingArguments(\n",
+ " per_device_train_batch_size = int(params['per_device_train_batch_size']), #config.get(\"training_config\").get('per_device_train_batch_size'),\n",
+ " gradient_accumulation_steps = params['gradient_accumulation_steps'], #config.get(\"training_config\").get('gradient_accumulation_steps'),\n",
+ " warmup_steps = config.get(\"training_config\").get('warmup_steps'),#params['warmup_steps'],\n",
+ " max_steps = config.get(\"training_config\").get(\"max_steps\"),\n",
+ " num_train_epochs = config.get(\"training_config\").get('num_train_epochs'),#params['num_train_epochs'],\n",
+ " learning_rate = params['learning_rate'],\n",
+ " fp16 = config.get(\"training_config\").get(\"fp16\"),\n",
+ " bf16 = config.get(\"training_config\").get(\"bf16\"),\n",
+ " logging_steps = config.get(\"training_config\").get(\"logging_steps\"),\n",
+ " optim = config.get(\"training_config\").get(\"optim\"),\n",
+ " weight_decay = config.get(\"training_config\").get(\"weight_decay\"),\n",
+ " lr_scheduler_type = config.get(\"training_config\").get(\"lr_scheduler_type\"),\n",
+ " seed = 42,\n",
+ " output_dir = config.get(\"training_config\").get(\"output_dir\")\n",
+ " )\n",
+ " )\n",
+ " trainer_stats = trainer.train()\n",
+ " return trainer_stats.loss # Assuming loss is the metric to minimize\n",
+ " except Exception as e:\n",
+ " print(f\"Error during training: {e}\")\n",
+ " return float(\"inf\") # Return high value for failed trials \n",
+ "# Create a Trials object to track hyperparameter evaluations\n",
+ "trials = Trials()\n",
+ "# Run hyperparameter optimization using TPE algorithm\n",
+ "best = fmin(objective, space, algo=tpe.suggest, trials=trials, max_evals=2)\n",
+ "# Print the best hyperparameters found during optimization\n",
+ "print(\"Best Hyperparameters:\", best) "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Analyzing Hyperparameters:\n",
+ "\n",
+ "* **Batch Size**: Generally, increasing the batch size can improve training speed by utilizing hardware resources more efficiently. However, there's a limit beyond which performance degrades. You can tune the batch size within a reasonable range (e.g., 2, 4, 8, 16) to see its impact.\n",
+ "\n",
+ "* **Learning Rate**: A higher learning rate can accelerate training initially. But, a too high value can lead to unstable training and potentially slower convergence. Consider a range of learning rates (e.g., log-uniform distribution between 1e-5 and 1e-3) for exploration.\n",
+ "\n",
+ "* **Gradient Accumulation Steps**: This technique accumulates gradients over multiple batches before updating model weights. It can help reduce memory requirements but might slow down training per epoch. Experiment with different accumulation steps (e.g., 1, 2, 4) to find a balance.\n",
+ "\n",
+ "* **Optimizer Choice**: Some optimizers like Adam or SGD with momentum can be faster than others depending on the model and dataset. Explore different optimizers and their hyperparameters (e.g., momentum coefficient) to see if they lead to faster convergence.\n",
+ "\n",
+ "## Additional Considerations:\n",
+ "\n",
+ "Early Stopping: Implement early stopping to automatically terminate training if the validation loss doesn't improve for a certain number of epochs. This can save training time if the model starts overfitting.\n",
+ "Warmup Steps: A gradual increase in the learning rate during the initial training phase (warmup steps) can improve stability and potentially accelerate convergence compared to a fixed learning rate from the beginning.\n",
+ "\n",
+ "\n",
+ "* Experimentation and Profiling:\n",
+ "\n",
+ "The best hyperparameters for faster training depend on your specific model, dataset, and hardware. You'll need to experiment with different configurations using tools like Hyperopt to find the optimal settings.\n",
+ "Consider using profiling tools to identify bottlenecks in your training pipeline. This can help you focus on optimizing specific parts of the training process that are most time-consuming.\n",
+ "By analyzing these hyperparameters and implementing techniques like early stopping and warmup steps, you can potentially achieve faster fine-tuning while maintaining good model performance."
+ ]
+ },
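+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# A minimal, hedged sketch (not from the original run) of how the early stopping and LR warmup\n",
+ "# ideas above could be wired into the SFTTrainer setup. It assumes a held-out `eval_dataset`\n",
+ "# split exists; adjust the step counts to your run length.\n",
+ "from transformers import TrainingArguments, EarlyStoppingCallback\n",
+ "\n",
+ "es_args = TrainingArguments(\n",
+ "    output_dir=\"outputs\",\n",
+ "    per_device_train_batch_size=2,\n",
+ "    warmup_steps=100,                   # gradual LR warmup, as discussed above\n",
+ "    evaluation_strategy=\"steps\",        # evaluate periodically so early stopping has a metric\n",
+ "    eval_steps=50,\n",
+ "    save_strategy=\"steps\",\n",
+ "    save_steps=50,\n",
+ "    load_best_model_at_end=True,        # required by EarlyStoppingCallback\n",
+ "    metric_for_best_model=\"eval_loss\",\n",
+ "    greater_is_better=False,\n",
+ ")\n",
+ "\n",
+ "# trainer = SFTTrainer(model=model, tokenizer=tokenizer, train_dataset=train_dataset,\n",
+ "#                      eval_dataset=eval_dataset, args=es_args,\n",
+ "#                      callbacks=[EarlyStoppingCallback(early_stopping_patience=3)])\n",
+ "# trainer.train()"
+ ]
+ },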
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Method 1 Optuna"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from optuna import create_study, Trial\n",
+ "import time # Assuming you can use time.time() to measure training time\n",
+ "\n",
+ "# Define search space with additional hyperparameter\n",
+ "search_space = {\n",
+ " \"learning_rate\": [1e-5, 5e-5, 1e-4, 2e-4],\n",
+ " \"per_device_train_batch_size\": [2, 4, 8],\n",
+ " \"lora_alpha\": [8, 16, 32],\n",
+ " \"gradient_accumulation_steps\": [1, 2, 4, 8], # Added gradient accumulation steps\n",
+ "}\n",
+ "\n",
+ "def objective(trial):\n",
+ " # Set hyperparameters based on trial values\n",
+ " config[\"training_config\"][\"learning_rate\"] = trial.suggest_float(\"learning_rate\", search_space[\"learning_rate\"][0], search_space[\"learning_rate\"][-1])\n",
+ " config[\"training_config\"][\"per_device_train_batch_size\"] = trial.suggest_int(\"per_device_train_batch_size\", search_space[\"per_device_train_batch_size\"][0], search_space[\"per_device_train_batch_size\"][-1])\n",
+ " config[\"training_config\"][\"gradient_accumulation_steps\"] = trial.suggest_int(\"gradient_accumulation_steps\", search_space[\"gradient_accumulation_steps\"][0], search_space[\"gradient_accumulation_steps\"][-1])\n",
+ " config[\"lora_config\"][\"lora_alpha\"] = trial.suggest_int(\"lora_alpha\", search_space[\"lora_alpha\"][0], search_space[\"lora_alpha\"][-1])\n",
+ "\n",
+ " # Train the model with the current hyperparameters\n",
+ " start_time = time.time()\n",
+ " try:\n",
+ " trainer_stats = trainer_test.train()\n",
+ " training_time = time.time() - start_time\n",
+ " return training_time # Minimize training time\n",
+ " except Exception as e:\n",
+ " return float(\"inf\") # Assign a high value if training fails\n",
+ "\n",
+ "study = create_study(direction=\"minimize\")\n",
+ "study.optimize(objective, n_trials=2) # Adjust the number of trials\n",
+ "\n",
+ "# Access the best trial and its hyperparameters after optimization\n",
+ "best_trial = study.best_trial\n",
+ "best_params = best_trial.params\n",
+ "\n",
+ "print(\"Best Trial:\", best_trial.number)\n",
+ "print(\"Best Hyperparameters (Likely Fastest):\", best_params)\n",
+ "print(\"Best Training Time:\", best_trial.value, \"seconds\")"
+ ]
+ },
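+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Hedged variant (not part of the original notebook): if each trial should pick only the exact\n",
+ "# values listed in `search_space` above (rather than sampling anywhere in the [min, max] range),\n",
+ "# Optuna's trial.suggest_categorical restricts the search to those grid values.\n",
+ "def objective_categorical(trial):\n",
+ "    config[\"training_config\"][\"learning_rate\"] = trial.suggest_categorical(\"learning_rate\", search_space[\"learning_rate\"])\n",
+ "    config[\"training_config\"][\"per_device_train_batch_size\"] = trial.suggest_categorical(\"per_device_train_batch_size\", search_space[\"per_device_train_batch_size\"])\n",
+ "    config[\"training_config\"][\"gradient_accumulation_steps\"] = trial.suggest_categorical(\"gradient_accumulation_steps\", search_space[\"gradient_accumulation_steps\"])\n",
+ "    config[\"lora_config\"][\"lora_alpha\"] = trial.suggest_categorical(\"lora_alpha\", search_space[\"lora_alpha\"])\n",
+ "    start_time = time.time()\n",
+ "    try:\n",
+ "        trainer_stats = trainer_test.train()   # NOTE: rebuild the trainer from `config` so the values take effect\n",
+ "        return time.time() - start_time        # minimize wall-clock training time\n",
+ "    except Exception:\n",
+ "        return float(\"inf\")\n",
+ "\n",
+ "# study = create_study(direction=\"minimize\")\n",
+ "# study.optimize(objective_categorical, n_trials=2)"
+ ]
+ },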
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "L4",
+ "machine_shape": "hm",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3.10",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "0005f2d9fe1e4cc98ea58b0c2868b433": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_45c1d5b0df0e420a87f791dd4cf0e425",
+ "max": 100,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_9ed49f1a099846a3a65cd6608bafb0e4",
+ "value": 100
+ }
+ },
+ "0058ed544fed4272848a891a68b9adc0": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "00eea4b0c6e44c62900ea8e7d919efe9": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "02fc530028ea4d538b7f6b48463ae700": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "058b2b9959b84b6f9f5d3862ef53d029": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_7807f312425b4f4d9249aa1ac77d7461",
+ "placeholder": "",
+ "style": "IPY_MODEL_d8e7ea9552a84b8284b31d77090b54af",
+ "value": "Map (num_proc=2): 100%"
+ }
+ },
+ "0f55ae30c2704632941cca4727c1c4f2": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "11dc1dcf6b29471580c32c818fa41d88": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_9344b22940c64654a82bb2ce06530e30",
+ "IPY_MODEL_4f68a26f64e844c7be21cc180eb6c1a2",
+ "IPY_MODEL_769b40273bab41af8eb66e494b613241"
+ ],
+ "layout": "IPY_MODEL_320c09781518483e82defa86c28316d1"
+ }
+ },
+ "1634ba52355b4681a913039666926f85": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_eff94d2d010e4b4f93a6dfcb61103a52",
+ "max": 18,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_da5cd094aaae45f4a0ca051ad5babd78",
+ "value": 18
+ }
+ },
+ "1850ab17bafd4a43b5ab5899d1875a40": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "1a72b512e1374e67a858edf2844fc157": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_284192f01a924f87afd8b5087ca9af6c",
+ "max": 18,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_273bf76f74bc4fb492ccb67d9e202f7b",
+ "value": 18
+ }
+ },
+ "217ca5cd404d4756a399fba3aa4fbc15": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_8f88a5b04723482ea430679e504c65f9",
+ "placeholder": "",
+ "style": "IPY_MODEL_8d153f070a8d4ad1b32996a9fd82beda",
+ "value": " 18/18 [00:00<00:00, 9.43it/s]"
+ }
+ },
+ "22ea45365d21439fb5069974bbe69711": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "23a71f8847e647daba35e495706fc846": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_22ea45365d21439fb5069974bbe69711",
+ "placeholder": "",
+ "style": "IPY_MODEL_bd087d0aa3214c5dbecc9b0bd4d976df",
+ "value": "Resolving data files: 100%"
+ }
+ },
+ "273bf76f74bc4fb492ccb67d9e202f7b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "284192f01a924f87afd8b5087ca9af6c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "2c5564fb033346afbe7692a24a52b302": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "31a203cdd2f54cda8a05214844888156": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "320c09781518483e82defa86c28316d1": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "32cff795f8bc490dbf63ed130e1f581f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "33fb10908c23457aa4796626102fc8c5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "341dca5ac74348dd9b5a347e38fa0b40": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "3564e3cf0fe84281838d84525794e735": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_912164947c5847908424f3e60c5adb64",
+ "IPY_MODEL_7517ce80636040e29665a9353afab183",
+ "IPY_MODEL_e14b9d980a1a41fb9e81385cb0f73d3a"
+ ],
+ "layout": "IPY_MODEL_ada78aafba3f47ab8eb45cf3c83a6805"
+ }
+ },
+ "37803098ceed4528bb690ebee028c840": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "39d3b72ab6214bcf9b0bb6b6294e957c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "3a97281be4c1433aa3abe6c25b7113e2": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_4e19e78059b842a5832ccae2f765a30c",
+ "IPY_MODEL_1a72b512e1374e67a858edf2844fc157",
+ "IPY_MODEL_c9cfd66b68a1437d946c83163fa877df"
+ ],
+ "layout": "IPY_MODEL_cccd970273ae43d2a6e60ac421bdc882"
+ }
+ },
+ "3f7afd4bd28842cbb73e62c155667030": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_9a5fd3a68fd1445f92bea51a7fec3e6b",
+ "max": 18,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_37803098ceed4528bb690ebee028c840",
+ "value": 18
+ }
+ },
+ "44f189b81bbd48ca8cb146ead641d2b5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_e903140c8c794c48b231924d3975b7a6",
+ "placeholder": "",
+ "style": "IPY_MODEL_7e74d789c82747e0b5066a00b9e36c1d",
+ "value": " 100/100 [00:00<00:00, 125.88 examples/s]"
+ }
+ },
+ "45b3259e3cac4de8bd19d12f07de2adb": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "45c1d5b0df0e420a87f791dd4cf0e425": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "4a0426a353ca41cba39d4dfeba925451": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "4e19e78059b842a5832ccae2f765a30c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_32cff795f8bc490dbf63ed130e1f581f",
+ "placeholder": "",
+ "style": "IPY_MODEL_4a0426a353ca41cba39d4dfeba925451",
+ "value": "Resolving data files: 100%"
+ }
+ },
+ "4f68a26f64e844c7be21cc180eb6c1a2": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_341dca5ac74348dd9b5a347e38fa0b40",
+ "max": 18,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_8ba6fd1bf16a4680b8a8c9c55ecf23e7",
+ "value": 18
+ }
+ },
+ "51a6d3c97480476e8c22d9ad670bdc47": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "53ee8f5e8b7d4076bdb0167baf2e5729": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "58b932a03b2c4aa4891d541f186244b9": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "5d1fbd3c62d94df7befdefc451221414": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_8ad6abb48f38469f9d399eea8f5e5b70",
+ "IPY_MODEL_6cea0da24cf54811a43168c606759bab",
+ "IPY_MODEL_eb8c88f5c06c49fe9099371b3cf112ae"
+ ],
+ "layout": "IPY_MODEL_89a1354722e640758978befc06ed4a78"
+ }
+ },
+ "64539b4212fe4d989976f56369bb746b": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "67b9a3505ae644dbb3c4fc14781a2731": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_53ee8f5e8b7d4076bdb0167baf2e5729",
+ "max": 100,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_d70fd9035f9b4d82892fae34c28c46d5",
+ "value": 100
+ }
+ },
+ "696e82ec6a174974a90d5abc7c101ee7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "6cea0da24cf54811a43168c606759bab": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_dade882aca304a31b693a2c58807d825",
+ "max": 18,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_02fc530028ea4d538b7f6b48463ae700",
+ "value": 18
+ }
+ },
+ "72eca1e2871b458abd3383d9711215a2": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_058b2b9959b84b6f9f5d3862ef53d029",
+ "IPY_MODEL_85d4879bd7d64766905db34cef052fed",
+ "IPY_MODEL_44f189b81bbd48ca8cb146ead641d2b5"
+ ],
+ "layout": "IPY_MODEL_f89c5c949e984361bce7f97d86d2a2e5"
+ }
+ },
+ "734b6d3e3406403293c4bc955a643528": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_dc3b2edc3f5d480a93b57b15b4444608",
+ "placeholder": "",
+ "style": "IPY_MODEL_7967d420aff1414e9fe53eb04c928eb4",
+ "value": "Map: 100%"
+ }
+ },
+ "7517ce80636040e29665a9353afab183": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_bb078c8c1f6a48359dc654d91ece684d",
+ "max": 18,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_9b9322336b564a409086955ebda07fc3",
+ "value": 18
+ }
+ },
+ "769b40273bab41af8eb66e494b613241": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_dc85f5e365f4488fa185d0ae35fde806",
+ "placeholder": "",
+ "style": "IPY_MODEL_51a6d3c97480476e8c22d9ad670bdc47",
+ "value": " 18/18 [00:00<00:00, 1567.70it/s]"
+ }
+ },
+ "7807f312425b4f4d9249aa1ac77d7461": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "793f49f397b54daab63194cee8d04256": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "7967d420aff1414e9fe53eb04c928eb4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "7e11cccce8be49008f8db3a0c3ea603d": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "7e74d789c82747e0b5066a00b9e36c1d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "82c6c2752a0746f3935e069c0f8811d6": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "85d4879bd7d64766905db34cef052fed": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_0058ed544fed4272848a891a68b9adc0",
+ "max": 100,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_33fb10908c23457aa4796626102fc8c5",
+ "value": 100
+ }
+ },
+ "89a1354722e640758978befc06ed4a78": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "8a195771bdc0462e8f9fbb60eb9141b1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "8a8d3a006ee24c4393d7c2f2d040ce52": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "8ad6abb48f38469f9d399eea8f5e5b70": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_39d3b72ab6214bcf9b0bb6b6294e957c",
+ "placeholder": "",
+ "style": "IPY_MODEL_696e82ec6a174974a90d5abc7c101ee7",
+ "value": "Resolving data files: 100%"
+ }
+ },
+ "8ba6fd1bf16a4680b8a8c9c55ecf23e7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "8d153f070a8d4ad1b32996a9fd82beda": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "8f88a5b04723482ea430679e504c65f9": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "912164947c5847908424f3e60c5adb64": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_ff108c92fb5547869ee545cf9a094b07",
+ "placeholder": "",
+ "style": "IPY_MODEL_2c5564fb033346afbe7692a24a52b302",
+ "value": "Loading dataset shards: 100%"
+ }
+ },
+ "9344b22940c64654a82bb2ce06530e30": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_793f49f397b54daab63194cee8d04256",
+ "placeholder": "",
+ "style": "IPY_MODEL_fa79cfa23f3a430dab69a59d93383cd0",
+ "value": "Resolving data files: 100%"
+ }
+ },
+ "963c0aa5620b4ea8b5a903894646121c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "9a5fd3a68fd1445f92bea51a7fec3e6b": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "9b9322336b564a409086955ebda07fc3": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "9bceb9eddb2147c1abbf3391c70e6784": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "9ed49f1a099846a3a65cd6608bafb0e4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "9f91f7ce62e243f59d72e5ba36f97b8f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_af0233735d744b7e838f50f52c9d6cbe",
+ "placeholder": "",
+ "style": "IPY_MODEL_8a8d3a006ee24c4393d7c2f2d040ce52",
+ "value": "Loading dataset shards: 100%"
+ }
+ },
+ "a419499622cd4374937423a79677298f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_b93514308ae44afbb1a0511f5f9c6ddf",
+ "placeholder": "",
+ "style": "IPY_MODEL_58b932a03b2c4aa4891d541f186244b9",
+ "value": " 18/18 [00:00<00:00, 1458.49it/s]"
+ }
+ },
+ "ada78aafba3f47ab8eb45cf3c83a6805": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "af0096de28414303ba5324f4087cd92e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "af0233735d744b7e838f50f52c9d6cbe": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "b7e7896aeac74b6eae27de0677100e57": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "b8b277831f1a45109b3a4a3565fbdb9d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_9f91f7ce62e243f59d72e5ba36f97b8f",
+ "IPY_MODEL_1634ba52355b4681a913039666926f85",
+ "IPY_MODEL_217ca5cd404d4756a399fba3aa4fbc15"
+ ],
+ "layout": "IPY_MODEL_bc6d92cb8837428bb7038d75e6af604e"
+ }
+ },
+ "b93514308ae44afbb1a0511f5f9c6ddf": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "bb078c8c1f6a48359dc654d91ece684d": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "bb1156b7d349440d9cc8a2f0328465a7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_23a71f8847e647daba35e495706fc846",
+ "IPY_MODEL_3f7afd4bd28842cbb73e62c155667030",
+ "IPY_MODEL_a419499622cd4374937423a79677298f"
+ ],
+ "layout": "IPY_MODEL_64539b4212fe4d989976f56369bb746b"
+ }
+ },
+ "bc6d92cb8837428bb7038d75e6af604e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "bd087d0aa3214c5dbecc9b0bd4d976df": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "be6162f66e594d3ebd8c53ebab3bbfa6": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_963c0aa5620b4ea8b5a903894646121c",
+ "placeholder": "",
+ "style": "IPY_MODEL_31a203cdd2f54cda8a05214844888156",
+ "value": " 100/100 [00:00<00:00, 5440.44 examples/s]"
+ }
+ },
+ "c4d39c87c16c4961b942d896742ff7ce": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_af0096de28414303ba5324f4087cd92e",
+ "placeholder": "",
+ "style": "IPY_MODEL_0f55ae30c2704632941cca4727c1c4f2",
+ "value": " 100/100 [00:01<00:00, 113.55 examples/s]"
+ }
+ },
+ "c9cfd66b68a1437d946c83163fa877df": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_45b3259e3cac4de8bd19d12f07de2adb",
+ "placeholder": "",
+ "style": "IPY_MODEL_b7e7896aeac74b6eae27de0677100e57",
+ "value": " 18/18 [00:00<00:00, 1.32it/s]"
+ }
+ },
+ "cccd970273ae43d2a6e60ac421bdc882": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "d70fd9035f9b4d82892fae34c28c46d5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "d8e7ea9552a84b8284b31d77090b54af": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "da5cd094aaae45f4a0ca051ad5babd78": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "dade882aca304a31b693a2c58807d825": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "dc3b2edc3f5d480a93b57b15b4444608": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "dc85f5e365f4488fa185d0ae35fde806": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e14b9d980a1a41fb9e81385cb0f73d3a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_9bceb9eddb2147c1abbf3391c70e6784",
+ "placeholder": "",
+ "style": "IPY_MODEL_8a195771bdc0462e8f9fbb60eb9141b1",
+ "value": " 18/18 [00:35<00:00, 1.20it/s]"
+ }
+ },
+ "e257e4a2bfdb48038102173d397ab2e4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_82c6c2752a0746f3935e069c0f8811d6",
+ "placeholder": "",
+ "style": "IPY_MODEL_1850ab17bafd4a43b5ab5899d1875a40",
+ "value": "Map (num_proc=2): 100%"
+ }
+ },
+ "e3bd7f85ce194cd4b697c2eb82038658": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_734b6d3e3406403293c4bc955a643528",
+ "IPY_MODEL_0005f2d9fe1e4cc98ea58b0c2868b433",
+ "IPY_MODEL_be6162f66e594d3ebd8c53ebab3bbfa6"
+ ],
+ "layout": "IPY_MODEL_7e11cccce8be49008f8db3a0c3ea603d"
+ }
+ },
+ "e5880b946aae4b84a94226a5d6acaf45": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e903140c8c794c48b231924d3975b7a6": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "eb8c88f5c06c49fe9099371b3cf112ae": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_00eea4b0c6e44c62900ea8e7d919efe9",
+ "placeholder": "",
+ "style": "IPY_MODEL_fe17bedb5ef04d8b9e064fa1e0d75185",
+ "value": " 18/18 [00:00<00:00, 1.42it/s]"
+ }
+ },
+ "eff94d2d010e4b4f93a6dfcb61103a52": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "f89c5c949e984361bce7f97d86d2a2e5": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "fa79cfa23f3a430dab69a59d93383cd0": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "fe17bedb5ef04d8b9e064fa1e0d75185": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "ff108c92fb5547869ee545cf9a094b07": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "ffa74977e7464cebb16d3cf8ee976d51": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_e257e4a2bfdb48038102173d397ab2e4",
+ "IPY_MODEL_67b9a3505ae644dbb3c4fc14781a2731",
+ "IPY_MODEL_c4d39c87c16c4961b942d896742ff7ce"
+ ],
+ "layout": "IPY_MODEL_e5880b946aae4b84a94226a5d6acaf45"
+ }
+ }
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/ai-medical-chatbot-master/6-FineTunning/MultipleGpu.ipynb b/ai-medical-chatbot-master/6-FineTunning/MultipleGpu.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..a1867627d5bbea89a162f8951574e8dc6957fbd7
--- /dev/null
+++ b/ai-medical-chatbot-master/6-FineTunning/MultipleGpu.ipynb
@@ -0,0 +1,737 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Multiple GPUS"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "from torch.utils.data import Dataset, DataLoader\n",
+ "from datautils import MyTrainDataset\n",
+ "import torch.multiprocessing as mp\n",
+ "from torch.utils.data.distributed import DistributedSampler\n",
+ "from torch.nn.parallel import DistributedDataParallel as DDP\n",
+ "from torch.distributed import init_process_group, destroy_process_group\n",
+ "import os\n",
+ "import argparse\n",
+ "from datasets import load_dataset\n",
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n",
+ "from peft import LoraConfig\n",
+ "from trl import SFTTrainer\n",
+ "from transformers import TrainingArguments\n",
+ "\n",
+ "def ddp_setup(rank, world_size):\n",
+ " os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
+ " os.environ[\"MASTER_PORT\"] = \"12355\"\n",
+ " init_process_group(backend=\"nccl\", rank=rank, world_size=world_size)\n",
+ " torch.cuda.set_device(rank)\n",
+ "\n",
+ "class Trainer:\n",
+ " def __init__(self, model, train_data, optimizer, gpu_id, save_every):\n",
+ " self.gpu_id = gpu_id\n",
+ " self.model = model.to(gpu_id)\n",
+ " self.train_data = train_data\n",
+ " self.optimizer = optimizer\n",
+ " self.save_every = save_every\n",
+ " f.model = DDP(model, device_ids=[gpu_id])\n",
+ "\n",
+ " def _run_batch(self, source, targets):\n",
+ " self.optimizer.zero_grad()\n",
+ " output = self.model(source)\n",
+ " loss = F.cross_entropy(output, targets)\n",
+ " loss.backward()\n",
+ " self.optimizer.step()\n",
+ "\n",
+ " def _run_epoch(self, epoch):\n",
+ " b_sz = len(next(iter(self.train_data))[0])\n",
+ " print(f\"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}\")\n",
+ " self.train_data.sampler.set_epoch(epoch)\n",
+ " for source, targets in self.train_data:\n",
+ " source = source.to(self.gpu_id)\n",
+ " targets = targets.to(self.gpu_id)\n",
+ " self._run_batch(source, targets)\n",
+ "\n",
+ " def _save_checkpoint(self, epoch):\n",
+ " ckp = self.model.module.state_dict()\n",
+ " PATH = \"checkpoint.pt\"\n",
+ " torch.save(ckp, PATH)\n",
+ " print(f\"Epoch {epoch} | Training checkpoint saved at {PATH}\")\n",
+ "\n",
+ " def train(self, max_epochs):\n",
+ " for epoch in range(max_epochs):\n",
+ " self._run_epoch(epoch)\n",
+ " if self.gpu_id == 0 and epoch % self.save_every == 0:\n",
+ " self._save_checkpoint(epoch)\n",
+ "\n",
+ "def load_train_objs():\n",
+ " dataset_name = \"ruslanmv/ai-medical-dataset\"\n",
+ " dataset = load_dataset(dataset_name, split=\"train\")\n",
+ " dataset = dataset.select(range(100))\n",
+ "\n",
+ " model_name = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
+ " tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
+ " tokenizer.pad_token = tokenizer.eos_token\n",
+ "\n",
+ " bnb_config = BitsAndBytesConfig(\n",
+ " load_in_4bit=True,\n",
+ " bnb_4bit_quant_type=\"nf4\",\n",
+ " bnb_4bit_compute_dtype=torch.float16,\n",
+ " )\n",
+ " model = AutoModelForCausalLM.from_pretrained(\n",
+ " model_name,\n",
+ " quantization_config=bnb_config,\n",
+ " trust_remote_code=True,\n",
+ " use_cache=False,\n",
+ " device_map=\"auto\",\n",
+ " )\n",
+ "\n",
+ " lora_alpha = 16\n",
+ " lora_dropout = 0.1\n",
+ " lora_r = 32\n",
+ " peft_config = LoraConfig(\n",
+ " lora_alpha=lora_alpha,\n",
+ " lora_dropout=lora_dropout,\n",
+ " r=lora_r,\n",
+ " bias=\"none\",\n",
+ " task_type=\"CAUSAL_LM\",\n",
+ " target_modules=[\"k_proj\", \"q_proj\", \"v_proj\", \"up_proj\", \"down_proj\", \"gate_proj\"],\n",
+ " modules_to_save=[\"embed_tokens\", \"input_layernorm\", \"post_attention_layernorm\", \"norm\"],\n",
+ " )\n",
+ "\n",
+ " max_seq_length = 512\n",
+ " output_dir = \"./results\"\n",
+ " per_device_train_batch_size = 2\n",
+ " gradient_accumulation_steps = 2\n",
+ " optim = \"adamw_torch\"\n",
+ " save_steps = 10\n",
+ " logging_steps = 1\n",
+ " learning_rate = 2e-4\n",
+ " max_grad_norm = 0.3\n",
+ " max_steps = 1\n",
+ " warmup_ratio = 0.1\n",
+ " lr_scheduler_type = \"cosine\"\n",
+ "\n",
+ " training_arguments = TrainingArguments(\n",
+ " output_dir=output_dir,\n",
+ " per_device_train_batch_size=per_device_train_batch_size,\n",
+ " gradient_accumulation_steps=gradient_accumulation_steps,\n",
+ " optim=optim,\n",
+ " save_steps=save_steps,\n",
+ " logging_steps=logging_steps,\n",
+ " learning_rate=learning_rate,\n",
+ " fp16=True,\n",
+ " max_grad_norm=max_grad_norm,\n",
+ " max_steps=max_steps,\n",
+ " warmup_ratio=warmup_ratio,\n",
+ " group_by_length=True,\n",
+ " lr_scheduler_type=lr_scheduler_type,\n",
+ " gradient_checkpointing=True,\n",
+ " )\n",
+ "\n",
+ " return dataset, model, peft_config, tokenizer, training_arguments\n",
+ "\n",
+ "def prepare_dataloader(dataset, batch_size):\n",
+ " return DataLoader(\n",
+ " dataset,\n",
+ " batch_size=batch_size,\n",
+ " pin_memory=True,\n",
+ " shuffle=False,\n",
+ " sampler=DistributedSampler(dataset),\n",
+ " )\n",
+ "\n",
+ "def main(rank, world_size, save_every, total_epochs, batch_size):\n",
+ " ddp_setup(rank, world_size)\n",
+ " dataset, model, peft_config, tokenizer, training_arguments = load_train_objs()\n",
+ " train_data = prepare_dataloader(dataset, batch_size)\n",
+ " trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " train_dataset=dataset,\n",
+ " peft_config=peft_config,\n",
+ " dataset_text_field=\"context\",\n",
+ " max_seq_length=max_seq_length,\n",
+ " tokenizer=tokenizer,\n",
+ " args=training_arguments,\n",
+ " )\n",
+ " trainer = Trainer(model, train_data, optimizer=trainer.optimizer, gpu_id=rank, save_every=save_every)\n",
+ " trainer.train(total_epochs)\n",
+ " destroy_process_group()\n",
+ "\n",
+ "TOTAL_EPOCHS = 10\n",
+ "SAVE_EVERY = 2\n",
+ "BATCH_SIZE = 32\n",
+ "\n",
+ "if __name__ == \"__main__\":\n",
+ " world_size = torch.cuda.device_count()\n",
+ " mp.set_start_method(\"spawn\", force=True) # Add this line\n",
+ " mp.spawn(main, args=(world_size, SAVE_EVERY, TOTAL_EPOCHS, BATCH_SIZE), nprocs=world_size)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "from torch.utils.data import Dataset, DataLoader\n",
+ "import torch.multiprocessing as mp\n",
+ "from torch.utils.data.distributed import DistributedSampler\n",
+ "from torch.nn.parallel import DistributedDataParallel as DDP\n",
+ "from torch.distributed import init_process_group, destroy_process_group\n",
+ "import os\n",
+ "import argparse\n",
+ "from datasets import load_dataset\n",
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer\n",
+ "from peft import LoraConfig\n",
+ "from trl import SFTTrainer\n",
+ "from transformers import TrainingArguments\n",
+ "def ddp_setup(rank, world_size):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " rank: Unique identifier of each process\n",
+ " world_size: Total number of processes\n",
+ " \"\"\"\n",
+ " os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
+ " os.environ[\"MASTER_PORT\"] = \"12355\"\n",
+ " init_process_group(backend=\"nccl\", rank=rank, world_size=world_size)\n",
+ " torch.cuda.set_device(rank)\n",
+ "\n",
+ "class Trainer:\n",
+ " def __init__(self, model, train_data, optimizer, gpu_id, save_every):\n",
+ " self.gpu_id = gpu_id\n",
+ " self.model = model.to(gpu_id)\n",
+ " self.train_data = train_data\n",
+ " self.optimizer = optimizer\n",
+ " self.save_every = save_every\n",
+ " self.model = DDP(model, device_ids=[gpu_id])\n",
+ "\n",
+ " def _run_batch(self, source, targets):\n",
+ " self.optimizer.zero_grad()\n",
+ " output = self.model(source)\n",
+ " loss = F.cross_entropy(output, targets)\n",
+ " loss.backward()\n",
+ " self.optimizer.step()\n",
+ "\n",
+ " def _run_epoch(self, epoch):\n",
+ " b_sz = len(next(iter(self.train_data))[0])\n",
+ " print(f\"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}\")\n",
+ " self.train_data.sampler.set_epoch(epoch)\n",
+ " for source, targets in self.train_data:\n",
+ " source = source.to(self.gpu_id)\n",
+ " targets = targets.to(self.gpu_id)\n",
+ " self._run_batch(source, targets)\n",
+ "\n",
+ " def _save_checkpoint(self, epoch):\n",
+ " ckp = self.model.module.statt()\n",
+ " PATH = \"checkpoint.pt\"\n",
+ " torch.save(ckp, PATH)\n",
+ " print(f\"Epoch {epoch} | Training checkpoint saved at {PATH}\")\n",
+ "\n",
+ " def train(self, max_epochs):\n",
+ " for epoch in range(max_epochs):\n",
+ " self._run_epoch(epoch)\n",
+ " if self.gpu_id == 0 and epoch % self.save_every == 0:\n",
+ " self._save_checkpoint(epoch)\n",
+ "\n",
+ "def load_train_objs():\n",
+ " dataset_name = \"ruslanmv/ai-medical-dataset\"\n",
+ " dataset = load_dataset(dataset_name, split=\"train\")\n",
+ " dataset = dataset.select(range(100))\n",
+ "\n",
+ " model_name = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
+ " tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
+ " tokenizer.pad_token = tokenizer.eos_token\n",
+ "\n",
+ " bnb_config = BitsAndBytesConfig(\n",
+ " load_in_4bit=True,\n",
+ " bnb_4bit_quant_type=\"nf4\",\n",
+ " bnb_4bit_compute_dtype=torch.float16,\n",
+ " )\n",
+ " model = AutoModelForCausalLM.from_pretrained(\n",
+ " model_name,\n",
+ " quantization_config=bnb_config,\n",
+ " trust_remote_code=True,\n",
+ " use_cache=False,\n",
+ " device_map=\"auto\",\n",
+ " )\n",
+ "\n",
+ " lora_alpha = 16\n",
+ " lora_dropout = 0.1\n",
+ " lora_r = 32\n",
+ " peft_config = LoraConfig(\n",
+ " lora_alpha=lora_alpha,\n",
+ " lora_dropout=lora_dropout,\n",
+ " r=lora_r,\n",
+ " bias=\"none\",\n",
+ " task_type=\"CAUSAL_LM\",\n",
+ " target_modules=[\"k_proj\", \"q_proj\", \"v_proj\", \"up_proj\", \"down_proj\", \"gate_proj\"],\n",
+ " modules_to_save=[\"embed_tokens\", \"input_layernorm\", \"post_attention_layernorm\", \"norm\"],\n",
+ " )\n",
+ "\n",
+ " max_seq_length = 512\n",
+ " output_dir = \"./results\"\n",
+ " per_device_train_batch_size = 2\n",
+ " gradient_accumulation_steps = 2\n",
+ " optim = \"adamw_torch\"\n",
+ " save_steps = 10\n",
+ " logging_steps = 1\n",
+ " learning_rate = 2e-4\n",
+ " max_grad_norm = 0.3\n",
+ " max_steps = 1\n",
+ " warmup_ratio = 0.1\n",
+ " lr_scheduler_type = \"cosine\"\n",
+ "\n",
+ " training_arguments = TrainingArguments(\n",
+ " output_dir=output_dir,\n",
+ " per_device_train_batch_size=per_device_train_batch_size,\n",
+ " gradient_accumulation_steps=gradient_accumulation_steps,\n",
+ " optim=optim,\n",
+ " save_steps=save_steps,\n",
+ " logging_steps=logging_steps,\n",
+ " learning_rate=learning_rate,\n",
+ " fp16=True,\n",
+ " max_grad_norm=max_grad_norm,\n",
+ " max_steps=max_steps,\n",
+ " warmup_ratio=warmup_ratio,\n",
+ " group_by_length=True,\n",
+ " lr_scheduler_type=lr_scheduler_type,\n",
+ " gradient_checkpointing=True,\n",
+ " )\n",
+ "\n",
+ " return dataset, model, peft_config, tokenizer, training_arguments\n",
+ "\n",
+ "def prepare_dataloader(dataset, batch_size):\n",
+ " return DataLoader(\n",
+ " dataset,\n",
+ " batch_size=batch_size,\n",
+ " pin_memory=True,\n",
+ " shuffle=False,\n",
+ " sampler=DistributedSampler(dataset),\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.multiprocessing as mp\n",
+ "\n",
+ "def main(rank, world_size):\n",
+ " # Define the parameters as constants\n",
+ " TOTAL_EPOCHS = 10\n",
+ " SAVE_EVERY = 2\n",
+ " BATCH_SIZE = 32\n",
+ " torch.cuda.init()\n",
+ " ddp_setup(rank, world_size) \n",
+ " dataset, model, peft_config, tokenizer, training_arguments = load_train_objs()\n",
+ " train_data = prepare_dataloader(dataset, BATCH_SIZE) # Corrected batch_size variable\n",
+ " trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " train_dataset=dataset,\n",
+ " peft_config=peft_config,\n",
+ " dataset_text_field=\"context\",\n",
+ " max_seq_length=max_seq_length,\n",
+ " tokenizer=tokenizer,\n",
+ " args=training_arguments,\n",
+ " )\n",
+ " trainer = Trainer(model, train_data, optimizer=trainer.optimizer, gpu_id=rank, save_every=SAVE_EVERY)\n",
+ " trainer.train(TOTAL_EPOCHS)\n",
+ " destroy_process_group()\n",
+ "\n",
+ "if __name__ == \"__main__\":\n",
+ " mp.set_start_method('spawn') # Set start method to 'spawn'\n",
+ " world_size = torch.cuda.device_count()\n",
+ "\n",
+ " # Workaround for Jupyter Notebook and interactive environments\n",
+ " processes = []\n",
+ " for rank in range(world_size):\n",
+ " p = mp.Process(target=main, args=(rank, world_size))\n",
+ " p.start()\n",
+ " processes.append(p)\n",
+ "\n",
+ " for p in processes:\n",
+ " p.join()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
+ "from datasets import load_dataset\n",
+ "from trl import SFTTrainer\n",
+ "from transformers import TrainingArguments\n",
+ "import os\n",
+ "import socket\n",
+ "\n",
+ "# Distributed training setup (assuming all GPUs are available on a single machine)\n",
+ "def init_distributed(rank, world_size):\n",
+ " \"\"\"Initializes distributed training using `nccl` backend.\"\"\"\n",
+ " if rank == 0:\n",
+ " os.environ[\"MASTER_ADDR\"] = socket.gethostname() # Set MASTER_ADDR using rank 0's hostname\n",
+ " else:\n",
+ " # Wait a bit to ensure MASTER_ADDR is set before other ranks try to use it\n",
+ " import time\n",
+ " time.sleep(5)\n",
+ " os.environ[\"MASTER_PORT\"] = \"12345\" # Set MASTER_PORT environment variable\n",
+ " os.environ[\"RANK\"] = str(rank) # Set RANK environment variable\n",
+ " os.environ[\"WORLD_SIZE\"] = str(world_size) # Set WORLD_SIZE environment variable\n",
+ " torch.distributed.init_process_group(backend='nccl', init_method='env://')\n",
+ "\n",
+ "# Cleanup after training\n",
+ "def cleanup_distributed():\n",
+ " if torch.distributed.is_initialized():\n",
+ " torch.distributed.destroy_process_group()\n",
+ "\n",
+ "# Model and tokenizer selection\n",
+ "model_name = \"facebook/bart-base\" # Replace with your desired model\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
+ "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
+ "\n",
+ "# Dataset loading (replace with your dataset and field names)\n",
+ "dataset = load_dataset(\"glue\", \"mnli\", split=\"train\")\n",
+ "text_field = \"premise\" # Assuming premise is the field containing text for prediction\n",
+ "\n",
+ "# Training arguments (adjust hyperparameters as needed)\n",
+ "training_args = TrainingArguments(\n",
+ " output_dir=\"./results\",\n",
+ " per_device_train_batch_size=2, # Adjust based on GPU memory (might need to adjust)\n",
+ " save_steps=500,\n",
+ " save_total_limit=2,\n",
+ " num_train_epochs=3, # Adjust training time as needed\n",
+ ")\n",
+ "\n",
+ "world_size = torch.cuda.device_count()\n",
+ "if world_size > 1:\n",
+ " # Initialize distributed training\n",
+ " init_distributed(rank=0, world_size=world_size) # Rank is assumed to be 0 here\n",
+ "\n",
+ " # Wrap model in DDP for distributed training\n",
+ " model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[torch.cuda.current_device()])\n",
+ "\n",
+ " # Create SFTTrainer with distributed settings\n",
+ " trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " args=training_args, # Pass training_args as 'args' instead of 'training_args'\n",
+ " train_dataset=dataset,\n",
+ " dataset_text_field=text_field,\n",
+ " compute_metrics=None, # You can define your custom metrics here\n",
+ " )\n",
+ " print(\"Trainer For distributed training loaded\")\n",
+ "else:\n",
+ " # For single-GPU training\n",
+ " trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " args=training_args, # Pass training_args as 'args' instead of 'training_args'\n",
+ " train_dataset=dataset,\n",
+ " dataset_text_field=text_field,\n",
+ " compute_metrics=None, # You can define your custom metrics here\n",
+ " )\n",
+ " print(\"Trainer For single-GPU loaded\")\n",
+ "\n",
+ "# Start training\n",
+ "trainer.train()\n",
+ "\n",
+ "# Cleanup after training\n",
+ "cleanup_distributed()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import torch\n",
+ "import torch.multiprocessing as mp\n",
+ "from datasets import load_dataset\n",
+ "import torch\n",
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer\n",
+ "from peft import LoraConfig\n",
+ "from trl import SFTTrainer\n",
+ "from transformers import TrainingArguments\n",
+ "from torch.nn.parallel import DistributedDataParallel as DDP\n",
+ "\n",
+ "\n",
+ "# Distributed training setup\n",
+ "def init_distributed():\n",
+ " os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
+ " os.environ[\"MASTER_PORT\"] = \"12345\"\n",
+ " torch.distributed.init_process_group(backend='nccl', world_size=torch.cuda.device_count(), rank=rank)\n",
+ "\n",
+ "def cleanup_distributed():\n",
+ " torch.distributed.destroy_process_group()\n",
+ "\n",
+ "def main_worker(rank, world_size):\n",
+ " init_distributed()\n",
+ "\n",
+ " # Your model training and fine-tuning code goes here\n",
+ " # Load the dataset\n",
+ " dataset_name = \"ruslanmv/ai-medical-dataset\"\n",
+ " dataset = load_dataset(dataset_name, split=\"train\")\n",
+ " # Select the first 1M rows of the dataset\n",
+ " dataset = dataset.select(range(100))\n",
+ "\n",
+ " # Load the model + tokenizer\n",
+ " model_name = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
+ " tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
+ " tokenizer.pad_token = tokenizer.eos_token\n",
+ " bnb_config = BitsAndBytesConfig(\n",
+ " load_in_4bit=True,\n",
+ " bnb_4bit_quant_type=\"nf4\",\n",
+ " bnb_4bit_compute_dtype=torch.float16,\n",
+ " )\n",
+ " model = AutoModelForCausalLM.from_pretrained(\n",
+ " model_name,\n",
+ " quantization_config=bnb_config,\n",
+ " trust_remote_code=True,\n",
+ " use_cache=False,\n",
+ " )\n",
+ "\n",
+ " # Check for available GPUs\n",
+ " device = torch.device(f\"cuda:{rank}\" if torch.cuda.is_available() else \"cpu\")\n",
+ "\n",
+ " # PEFT config\n",
+ " lora_alpha = 1\n",
+ " lora_dropout = 0.1\n",
+ " lora_r = 32 # 64\n",
+ " peft_config = LoraConfig(\n",
+ " lora_alpha=lora_alpha,\n",
+ " lora_dropout=lora_dropout,\n",
+ " task_type=\"CAUSAL_LM\",\n",
+ " target_modules=[\"k_proj\", \"q_proj\", \"v_proj\", \"up_proj\", \"down_proj\", \"gate_proj\"],\n",
+ " modules_to_save=[\"embed_tokens\", \"input_layernorm\", \"post_attention_layernorm\", \"norm\"],\n",
+ " )\n",
+ "\n",
+ " # Args\n",
+ " max_seq_length = 512\n",
+ " output_dir = \"./results\"\n",
+ " per_device_train_batch_size = 2 # reduced batch size to avoid OOM\n",
+ " gradient_accumulation_steps = 2\n",
+ " optim = \"adamw_torch\"\n",
+ " save_steps = 10\n",
+ " logging_steps = 1\n",
+ " learning_rate = 2e-4\n",
+ " max_grad_norm = 0.3\n",
+ " max_steps = 1 # 300 Approx the size of guanaco at bs 8, ga 2, 2 GPUs.\n",
+ " warmup_ratio = 0.1\n",
+ " lr_scheduler_type = \"cosine\"\n",
+ " training_arguments = TrainingArguments(\n",
+ " output_dir=output_dir,\n",
+ " per_device_train_batch_size=per_device_train_batch_size,\n",
+ " gradient_accumulation_steps=gradient_accumulation_steps,\n",
+ " optim=optim,\n",
+ " save_steps=save_steps,\n",
+ " logging_steps=logging_steps,\n",
+ " learning_rate=learning_rate,\n",
+ " fp16=True,\n",
+ " max_grad_norm=max_grad_norm,\n",
+ " max_steps=max_steps,\n",
+ " warmup_ratio=warmup_ratio,\n",
+ " group_by_length=True,\n",
+ " lr_scheduler_type=lr_scheduler_type,\n",
+ " gradient_checkpointing=True, # gradient checkpointing\n",
+ " #report_to=\"wandb\",\n",
+ " )\n",
+ "\n",
+ " # Trainer\n",
+ " trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " train_dataset=dataset,\n",
+ " peft_config=peft_config,\n",
+ " dataset_text_field=\"context\",\n",
+ " max_seq_length=max_seq_length,\n",
+ " tokenizer=tokenizer,\n",
+ " args=training_arguments,\n",
+ " )\n",
+ "\n",
+ " # Train :)\n",
+ " trainer.train()\n",
+ " cleanup_distributed()\n",
+ "\n",
+ "\n",
+ "if __name__ == \"__main__\":\n",
+ " world_size = torch.cuda.device_count()\n",
+ " mp.set_start_method('spawn') # Add this line to fix the error\n",
+ " processes = []\n",
+ " for rank in range(world_size):\n",
+ " p = mp.Process(target=main_worker, args=(rank, world_size))\n",
+ " p.start()\n",
+ " processes.append(p)\n",
+ " for p in processes:\n",
+ " p.join()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def finetune():\n",
+ " from datasets import load_dataset\n",
+ " import torch\n",
+ " from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer\n",
+ " from peft import LoraConfig\n",
+ " from trl import SFTTrainer\n",
+ " from transformers import TrainingArguments\n",
+ " from torch.nn.parallel import DistributedDataParallel as DDP\n",
+ " # Load the dataset\n",
+ " dataset_name = \"ruslanmv/ai-medical-dataset\"\n",
+ " dataset = load_dataset(dataset_name, split=\"train\")\n",
+ " # Select the first 1M rows of the dataset\n",
+ " dataset = dataset.select(range(100))\n",
+ " # Load the model + tokenizer\n",
+ " model_name = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
+ " tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
+ " tokenizer.pad_token = tokenizer.eos_token\n",
+ " bnb_config = BitsAndBytesConfig(\n",
+ " load_in_4bit=True,\n",
+ " bnb_4bit_quant_type=\"nf4\",\n",
+ " bnb_4bit_compute_dtype=torch.float16,\n",
+ " )\n",
+ " model = AutoModelForCausalLM.from_pretrained(\n",
+ " model_name,\n",
+ " quantization_config=bnb_config,\n",
+ " trust_remote_code=True,\n",
+ " use_cache=False,\n",
+ " )\n",
+ " # Check for available GPUs\n",
+ " if torch.cuda.device_count() > 1:\n",
+ " print(\"Multiple GPUs detected, enabling DataParallel...\")\n",
+ " model = DDP(model) # Wrap the model with DDP\n",
+ " else:\n",
+ " print(\"Using single GPU...\")\n",
+ " # PEFT config\n",
+ " lora_alpha = 16\n",
+ " lora_dropout = 0.1\n",
+ " lora_r = 32 # 64\n",
+ " peft_config = LoraConfig(\n",
+ " lora_alpha=lora_alpha,\n",
+ " lora_dropout=lora_dropout,\n",
+ " r=lora_r,\n",
+ " bias=\"none\",\n",
+ " task_type=\"CAUSAL_LM\",\n",
+ " target_modules=[\"k_proj\", \"q_proj\", \"v_proj\", \"up_proj\", \"down_proj\", \"gate_proj\"],\n",
+ " modules_to_save=[\"embed_tokens\", \"input_layernorm\", \"post_attention_layernorm\", \"norm\"],\n",
+ " )\n",
+ " # Args\n",
+ " max_seq_length = 512\n",
+ " output_dir = \"./results\"\n",
+ " per_device_train_batch_size = 2 # reduced batch size to avoid OOM\n",
+ " gradient_accumulation_steps = 2\n",
+ " optim = \"adamw_torch\"\n",
+ " save_steps = 10\n",
+ " logging_steps = 1\n",
+ " learning_rate = 2e-4\n",
+ " max_grad_norm = 0.3\n",
+ " max_steps = 1 # 300 Approx the size of guanaco at bs 8, ga 2, 2 GPUs.\n",
+ " warmup_ratio = 0.1\n",
+ " lr_scheduler_type = \"cosine\"\n",
+ "\n",
+ " training_arguments = TrainingArguments(\n",
+ " output_dir=output_dir,\n",
+ " per_device_train_batch_size=per_device_train_batch_size,\n",
+ " gradient_accumulation_steps=gradient_accumulation_steps,\n",
+ " optim=optim,\n",
+ " save_steps=save_steps,\n",
+ " logging_steps=logging_steps,\n",
+ " learning_rate=learning_rate,\n",
+ " fp16=True,\n",
+ " max_grad_norm=max_grad_norm,\n",
+ " max_steps=max_steps,\n",
+ " warmup_ratio=warmup_ratio,\n",
+ " group_by_length=True,\n",
+ " lr_scheduler_type=lr_scheduler_type,\n",
+ " gradient_checkpointing=True, # gradient checkpointing\n",
+ " #report_to=\"wandb\",\n",
+ " )\n",
+ " # Trainer\n",
+ " trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " train_dataset=dataset,\n",
+ " peft_config=peft_config,\n",
+ " dataset_text_field=\"context\",\n",
+ " max_seq_length=max_seq_length,\n",
+ " tokenizer=tokenizer,\n",
+ " args=training_arguments,\n",
+ " )\n",
+ " # Train :)\n",
+ " trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import torch\n",
+ "import torch.multiprocessing as mp\n",
+ "\n",
+ "def init_distributed(rank, world_size, local_rank=0): # Add local_rank argument\n",
+ " os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
+ " os.environ[\"MASTER_PORT\"] = \"12345\" # Adjust port if needed\n",
+ " if rank == 0:\n",
+ " print(\"Initializing distributed process group...\")\n",
+ " torch.distributed.init_process_group(backend='nccl', world_size=world_size, rank=rank)\n",
+ " torch.cuda.set_device(local_rank) # Set unique GPU device for each process\n",
+ "\n",
+ "def cleanup_distributed():\n",
+ " torch.distributed.destroy_process_group()\n",
+ "\n",
+ "def main_worker(rank, world_size):\n",
+ " local_rank = rank % torch.cuda.device_count() # Assign unique local rank\n",
+ " init_distributed(rank, world_size, local_rank)\n",
+ " # Your model training and fine-tuning code goes here with model on local_rank GPU\n",
+ " finetune() # Move model to assigned GPU\n",
+ " cleanup_distributed()\n",
+ "if __name__ == \"__main__\":\n",
+ " world_size = torch.cuda.device_count()\n",
+ "\n",
+ " # Workaround for Jupyter Notebook and interactive environments\n",
+ " processes = []\n",
+ " for rank in range(world_size):\n",
+ " p = mp.Process(target=main_worker, args=(rank, world_size))\n",
+ " p.start()\n",
+ " processes.append(p)\n",
+ "\n",
+ " for p in processes:\n",
+ " p.join()\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/ai-medical-chatbot-master/6-FineTunning/Parallel-FineTuning.ipynb b/ai-medical-chatbot-master/6-FineTunning/Parallel-FineTuning.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..7e2e8c2b18b65cb3472e1a199d668f558e60440b
--- /dev/null
+++ b/ai-medical-chatbot-master/6-FineTunning/Parallel-FineTuning.ipynb
@@ -0,0 +1,136 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "import os\n",
+ "import torch\n",
+ "import torch.multiprocessing as mp\n",
+ "from torch.nn.parallel import DistributedDataParallel as DDP\n",
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer, TrainingArguments\n",
+ "from datasets import load_dataset\n",
+ "from trl import SFTTrainer\n",
+ "from peft import LoraConfig\n",
+ "\n",
+ "def init_distributed(rank, world_size):\n",
+ " os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
+ " os.environ[\"MASTER_PORT\"] = \"12345\"\n",
+ " if rank == 0:\n",
+ " print(\"Initializing distributed process group...\")\n",
+ " torch.distributed.init_process_group(backend='nccl', world_size=world_size, rank=rank)\n",
+ "\n",
+ "def cleanup_distributed():\n",
+ " torch.distributed.destroy_process_group()\n",
+ "\n",
+ "def main_worker(rank, world_size):\n",
+ " init_distributed(rank, world_size)\n",
+ "\n",
+ " # Move the finetune() function here\n",
+ " # Load the dataset\n",
+ " dataset_name = \"ruslanmv/ai-medical-dataset\"\n",
+ " dataset = load_dataset(dataset_name, split=\"train\")\n",
+ " # Select the first 1M rows of the dataset\n",
+ " dataset = dataset.select(range(100))\n",
+ " # Load the model + tokenizer\n",
+ " model_name = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
+ " tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
+ " tokenizer.pad_token = tokenizer.eos_token\n",
+ " bnb_config = BitsAndBytesConfig(\n",
+ " load_in_4bit=True,\n",
+ " bnb_4bit_quant_type=\"nf4\",\n",
+ " bnb_4bit_compute_dtype=torch.float16,\n",
+ " )\n",
+ " model = AutoModelForCausalLM.from_pretrained(\n",
+ " model_name,\n",
+ " quantization_config=bnb_config,\n",
+ " trust_remote_code=True,\n",
+ " use_cache=False,\n",
+ " )\n",
+ " # Replace the DDP wrapping part with the following lines\n",
+ " model = model.to(rank)\n",
+ " model = DDP(model, device_ids=[rank], output_device=rank)\n",
+ "\n",
+ " # PEFT config\n",
+ " lora_alpha = 16\n",
+ " lora_dropout = 0.1\n",
+ " lora_r = 32 # 64\n",
+ " peft_config = LoraConfig(\n",
+ " lora_alpha=lora_alpha,\n",
+ " lora_dropout=lora_dropout,\n",
+ " r=lora_r,\n",
+ " bias=\"none\",\n",
+ " task_type=\"CAUSAL_LM\",\n",
+ " target_modules=[\"k_proj\", \"q_proj\", \"v_proj\", \"up_proj\", \"down_proj\", \"gate_proj\"],\n",
+ " modules_to_save=[\"embed_tokens\", \"input_layernorm\", \"post_attention_layernorm\", \"norm\"],\n",
+ " )\n",
+ " # Args\n",
+ " max_seq_length = 512\n",
+ " output_dir = \"./results\"\n",
+ " per_device_train_batch_size = 2 # reduced batch size to avoid OOM\n",
+ " gradient_accumulation_steps = 2\n",
+ " optim = \"adamw_torch\"\n",
+ " save_steps = 10\n",
+ " logging_steps = 1\n",
+ " learning_rate = 2e-4\n",
+ " max_grad_norm = 0.3\n",
+ " max_steps = 1 # 300 Approx the size of guanaco at bs 8, ga 2, 2 GPUs.\n",
+ " warmup_ratio = 0.1\n",
+ " lr_scheduler_type = \"cosine\"\n",
+ "\n",
+ " training_arguments = TrainingArguments(\n",
+ " output_dir=output_dir,\n",
+ " per_device_train_batch_size=per_device_train_batch_size,\n",
+ " gradient_accumulation_steps=gradient_accumulation_steps,\n",
+ " optim=optim,\n",
+ " save_steps=save_steps,\n",
+ " logging_steps=logging_steps,\n",
+ " learning_rate=learning_rate,\n",
+ " fp16=True,\n",
+ " max_grad_norm=max_grad_norm,\n",
+ " max_steps=max_steps,\n",
+ " warmup_ratio=warmup_ratio,\n",
+ " group_by_length=True,\n",
+ " lr_scheduler_type=lr_scheduler_type,\n",
+ " gradient_checkpointing=True, # gradient checkpointing\n",
+ " report_to=\"wandb\",\n",
+ " )\n",
+ " # Trainer\n",
+ " trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " train_dataset=dataset,\n",
+ " peft_config=peft_config,\n",
+ " dataset_text_field=\"context\",\n",
+ " max_seq_length=max_seq_length,\n",
+ " tokenizer=tokenizer,\n",
+ " args=training_arguments,\n",
+ " )\n",
+ " # Train :)\n",
+ " trainer.train()\n",
+ " cleanup_distributed()\n",
+ "\n",
+ "if __name__ == \"__main__\":\n",
+ " world_size = torch.cuda.device_count()\n",
+ "\n",
+ " processes = []\n",
+ " for rank in range(world_size):\n",
+ " p = mp.Process(target=main_worker, args=(rank, world_size))\n",
+ " p.start()\n",
+ " processes.append(p)\n",
+ "\n",
+ " for p in processes:\n",
+ " p.join()\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/ai-medical-chatbot-master/6-FineTunning/README.md b/ai-medical-chatbot-master/6-FineTunning/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ai-medical-chatbot-master/7-Multimodal/README.md b/ai-medical-chatbot-master/7-Multimodal/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c265f463953d2ca23742055f832a9d1def1bb7fb
--- /dev/null
+++ b/ai-medical-chatbot-master/7-Multimodal/README.md
@@ -0,0 +1,81 @@
+# Multimodal Medical Chatbot
+We want to build a medical chatbot that, given a picture, can describe symptoms and suggest possible solutions.
+
+First, we need to build our dataset.
+
+
+
+
+# Medical-Images Datasets
+
+* A list of Medical imaging datasets. Source : https://sites.google.com/site/aacruzr/image-datasets
+* An additional, possibly overlapping list can be found at : https://github.com/beamandrew/medical-data
+
+### Multimodal databases
+
+* Center for Invivo Microscopy (CIVM), Embrionic and Neonatal Mouse (H&E, MR) http://www.civm.duhs.duke.edu/devatlas/
+user guide: http://www.civm.duhs.duke.edu/devatlas/UserGuide.pdf
+* LONI image data archive https://ida.loni.usc.edu/services/Menu/IdaData.jsp?project=
+* Radiology (Ultrasound, Mammographs, X-Ray, CT, MRI, fMRI, etc.)
+* Collaborative Informatics and Neuroimaging Suite (COINS) https://portal.mrn.org/micis/index.php?subsite=dx
+* The Cancer Imaging Archive (TCIA) http://www.cancerimagingarchive.net/ (Collections)
+* Alzheimer’s Disease Neuroimaging Initiative (ADNI) http://adni.loni.ucla.edu/
+* The Open Access Series of Imaging Studies (OASIS) http://www.oasis-brains.org/
+* Breast Cancer Digital Repository https://bcdr.eu/
+* DDSM: Digital Database for Screening Mammography http://marathon.csee.usf.edu/Mammography/Database.html
+* The Mammographic Image Analysis Society (MIAS) mini-database http://peipa.essex.ac.uk/info/mias.html
+* Mammography Image Databases 100 or more images of mammograms with ground truth. Additional images available by request, and links to several other mammography databases are provided http://marathon.csee.usf.edu/Mammography/Database.html
+* NLM HyperDoc Visible Human Project color, CAT and MRI image samples - over 30 images http://www.nlm.nih.gov/research/visible/visible_human.html
+* CT Scans for Colon Cancer https://wiki.cancerimagingarchive.net/display/Public/CT+COLONOGRAPHY#e88604ec5c654f60a897fa77906f88a6
+* [[BreastScreening](http://breastscreening.github.io/)] UTA4: Breast Cancer Medical Imaging DICOM Files Dataset & Resources (MG, US and MRI) https://github.com/MIMBCD-UI/dataset-uta4-dicom
+* [[MIMBCD-UI](http://mimbcd-ui.github.io/)] UTA7: Breast Cancer Medical Imaging DICOM Files Dataset & Resources (MG, US and MRI) https://github.com/MIMBCD-UI/dataset-uta7-dicom
+* [[Facebook AI + NYU FastMRI](https://fastmri.org/dataset/)] includes two types of MRI scans: knee MRIs and the brain (neuro) MRIs, containing training, validation, and masked test sets. Also includes PyTorch data loaders in open-sourced [GitHub Repository](https://github.com/facebookresearch/fastMRI/)
+* BCNB: Early Breast Cancer Core-Needle Biopsy WSI Dataset, https://bupt-ai-cz.github.io/BCNB/, https://github.com/bupt-ai-cz/BALNMP#bcnb-dataset
+
+### Histology and Histopathology (H&E, IHQ, ...)
+
+* The Cancer Genome Atlas (TCGA) http://cancergenome.nih.gov/ https://tcga-data.nci.nih.gov/tcga/
+* International Cancer Genome Consortium http://icgc.org, (Data portal) http://dcc.icgc.org/
+* Stanford Tissue Microarray Database (TMA) http://tma.im
+* MITOS dataset http://www.ipal.cnrs.fr/event/icpr-2012
+* Cancer Image Database (caIMAGE) https://emice.nci.nih.gov/caimage
+* DPA’s Whole Slide Imaging Repository https://digitalpathologyassociation.org/whole-slide-imaging-repository
+* ITK Analysis of Large Histology Datasets http://www.na-mic.org/Wiki/index.php/ITK_Analysis_of_Large_Histology_Datasets
+* Histology Photo Album http://www.histology-world.com/photoalbum/thumbnails.php?album=52
+* Slide Library of Virtual pathology, University of Leeds http://www.virtualpathology.leeds.ac.uk/
+* Aperio Images http://images.aperio.com/
+* HAPS Histology Image Database http://hapshistology.wikifoundry.com/
+* Microscopy (Cell, Cytology, Biology, Protein, Molecular, Fluorescence, etc.)
+* BDGP images from the FlyExpress database www.flyexpress.net
+* The UCSB Bio-Segmentation Benchmark dataset http://www.bioimage.ucsb.edu/research/biosegmentation
+* Pap Smear database http://mde-lab.aegean.gr/index.php/downloads
+* Histology (CIMA) dataset http://cmp.felk.cvut.cz/~borovji3/?page=dataset
+* ANHIR dataset https://anhir.grand-challenge.org/
+* Genome RNAi dataset http://www.genomernai.org/
+* Chinese Hamster Ovary cells (CHO) dataset http://www.chogenome.org/data.html
+* Locate Endogenus mouse sub-cellular organelles (END) database http://locate.imb.uq.edu.au/
+* 2D HeLa dataset (HeLa) dataset https://ome.grc.nia.nih.gov/iicbu2008/hela/index.html
+* Allen Brain Atlas http://www.brain-map.org/
+* 1000 Functional Connectomes Project http://fcon_1000.projects.nitrc.org/
+* The Cell Centered Database (CCDB) https://library.ucsd.edu/dc/collection/bb5940732k
+* The Encyclopedia of DNA Elements (ENCODE) http://genome.ucsc.edu/ENCODE/
+user guide: http://www.plosbiology.org/article/info:doi/10.1371/journal.pbio.1001046
+* The Human Protein Atlas: http://www.proteinatlas.org/
+* DRIVE: Digital Retinal Images for Vessel Extraction http://www.isi.uu.nl/Research/Databases/DRIVE/ (Ground truth)
+* El Salvador Atlas of Gastrointestinal VideoEndoscopy Images and Videos of hi-res of studies taken from Gastrointestinal Video endoscopy http://www.gastrointestinalatlas.com/
+* BCNB: Early Breast Cancer Core-Needle Biopsy WSI Dataset, https://bupt-ai-cz.github.io/BCNB/, https://github.com/bupt-ai-cz/BALNMP#bcnb-dataset
+
+### Databases you can use for benchmarking
+
+* http://peipa.essex.ac.uk/benchmark/databases/
+* http://mulan.sourceforge.net/datasets-mlc.html
+* https://archive.ics.uci.edu/ml/datasets.php
+* Datasets reporting formats for pathologists http://www.rcpath.org/publications-media/publications/datasets
+* DermNet - Skin disease atlas (23 image classes and 23,000 images): http://www.dermnet.com/
+
+### State of the art / Challenges
+
+* Grand Challenges in Medical Image Analysis https://grand-challenge.org/
+* Challenges in global health and development problems https://grandchallenges.org/#/map
+* Current state of the art of most used computer vision datasets: Who is the best at X? http://rodrigob.github.io/are_we_there_yet/build/
+* Automatic Non-rigid Histological Image Registration (ANHIR) challenge https://anhir.grand-challenge.org/
\ No newline at end of file
diff --git a/ai-medical-chatbot-master/8-Interviewer/README.md b/ai-medical-chatbot-master/8-Interviewer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d66a15b23c87c2b41d840db9e71ace2e0b8af5b6
--- /dev/null
+++ b/ai-medical-chatbot-master/8-Interviewer/README.md
@@ -0,0 +1,82 @@
+# Medical Interviewer
+![](assets/2024-09-08-19-33-27.png)
+
+Medical Interviewer is an AI platform designed to simulate medical interviews based on psychological and medical knowledge. This platform is meant for educational, research, and preliminary assessment purposes but should **not** be used as a substitute for professional medical advice.
+
+## Disclaimer
+- This is a simulation tool, and information shared is confidential but **should not** include sensitive data.
+- Always seek professional medical advice from qualified healthcare providers.
+
+## Features
+### 1. Simulated Interviews
+- Conducts interviews with medically relevant and psychologically informed questions.
+- Interviews can be customized by selecting different interviewers with unique backgrounds and approaches:
+  - **Sarah**: Compassionate, with experience in trauma and family therapy.
+  - **Aaron**: Direct, specializing in stress management, trauma, and military-related cases.
+
+### 2. Natural Language Processing (NLP)
+- Uses **LangChain** to create NLP chains that drive the interview and report generation processes.
+- Understands user input and generates contextually relevant follow-up questions.
+
+### 3. Voice Interaction
+- Enables voice-based conversation by converting text to speech (TTS) and transcribing user responses using **Whisper** speech-to-text technology.
+
+### 4. Medical Report Generation
+- Automatically generates detailed medical reports based on the interview session.
+- Users can download the report in PDF format.
+
+### 5. Document Upload and Report Generation
+- Users can upload TXT, PDF, or DOCX files, and the system will generate a report from the document's content.
+
+### 6. Multi-Language Support
+- Supports interviews and report generation in multiple languages.
+
+### 7. Retrieval-Augmented Generation (RAG) and Document Retrieval
+- Utilizes **RAG** to retrieve relevant data from indexed medical and psychological resources such as DSM-5, PDM-2, Big Five Traits, etc., ensuring contextually accurate interview questions.
+
+## Usage
+
+### Start an Interview
+1. Choose an interviewer (Sarah or Aaron).
+2. Start the interview in either text or voice format.
+3. The system will guide you through a series of questions, designed to simulate a medical interview.
+
+### Upload Documents
+1. Navigate to the "Upload Document" tab.
+2. Upload your medical document (TXT, PDF, or DOCX).
+3. Generate a report based on the content of the document, which will be available for download.
+
+### Settings
+- Customize your experience by enabling or disabling audio interaction.
+- Choose between available interviewers, each providing a unique style and focus.
+
+## Technical Overview
+The system uses a combination of:
+- **NLP** for question generation and contextual relevance.
+- **FAISS Indexing** for document similarity and retrieval of relevant medical knowledge (see the retrieval sketch after this list).
+- **RAG** to ensure contextually relevant and accurate interview processes.
+- **Gradio** for the user interface, allowing text and audio inputs with real-time responses.
+- **Whisper** and **TTS** for audio interaction, enabling real-time simulated voice-based interviews.
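+
+A minimal sketch of the query-time retrieval step, based on `knowledge_retrieval.py` (the index path matches this repository; the query string and top-k value below are illustrative):
+
+```python
+from langchain_community.vectorstores import FAISS
+from langchain_openai import OpenAIEmbeddings
+
+# Load the prebuilt FAISS index shipped with the app.
+embedding_model = OpenAIEmbeddings()  # reads OPENAI_API_KEY from the environment
+index = FAISS.load_local(
+    "knowledge/faiss_index_all_documents",
+    embedding_model,
+    allow_dangerous_deserialization=True,
+)
+
+# Embed the query and fetch the most similar document chunks.
+retriever = index.as_retriever(search_kwargs={"k": 4})
+docs = retriever.invoke("patient reports persistent insomnia and anxiety")
+for doc in docs:
+    print(doc.page_content[:200])
+```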
+
+## Installation
+
+To run the application locally:
+
+1. Clone the repository:
+ ```bash
+ git clone https://github.com/ruslanmv/ai-medical-chatbot
+ cd ./ai-medical-chatbot/8-Interviewer/hf
+ ```
+
+2. Install the dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+3. Launch the application:
+ ```bash
+ python app.py
+ ```
+
+## License
+This project is licensed under the MIT License. See the [LICENSE](../LICENSE.txt) file for more details.
diff --git a/ai-medical-chatbot-master/8-Interviewer/assets/2024-09-08-19-33-27.png b/ai-medical-chatbot-master/8-Interviewer/assets/2024-09-08-19-33-27.png
new file mode 100644
index 0000000000000000000000000000000000000000..01ab6378fbee94e9b0e04d17daf8cf4fb95f076c
Binary files /dev/null and b/ai-medical-chatbot-master/8-Interviewer/assets/2024-09-08-19-33-27.png differ
diff --git a/ai-medical-chatbot-master/8-Interviewer/hf/.gitattributes b/ai-medical-chatbot-master/8-Interviewer/hf/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b
--- /dev/null
+++ b/ai-medical-chatbot-master/8-Interviewer/hf/.gitattributes
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
diff --git a/ai-medical-chatbot-master/8-Interviewer/hf/README.md b/ai-medical-chatbot-master/8-Interviewer/hf/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8be76cbd3eb80a4ba14def723e2f492655bfcf49
--- /dev/null
+++ b/ai-medical-chatbot-master/8-Interviewer/hf/README.md
@@ -0,0 +1,12 @@
+---
+title: Medical Interviewer
+emoji: 👩🦳
+colorFrom: pink
+colorTo: yellow
+sdk: gradio
+sdk_version: 4.41.0
+app_file: app.py
+pinned: false
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/ai-medical-chatbot-master/8-Interviewer/hf/ai_config.py b/ai-medical-chatbot-master/8-Interviewer/hf/ai_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0af4e1fc2badf7e6b1762f816815f58468e62516
--- /dev/null
+++ b/ai-medical-chatbot-master/8-Interviewer/hf/ai_config.py
@@ -0,0 +1,69 @@
+from io import BytesIO
+
+from langchain_openai import ChatOpenAI
+from openai import OpenAI
+import tiktoken
+import os
+from dotenv import load_dotenv
+# Load environment variables from .env file
+load_dotenv()
+
+# OpenAI connection parameters (loaded from environment variables)
+openai_api_key = os.getenv("OPENAI_API_KEY")
+
+def n_of_questions():
+    return 25
+
+#openai_api_key = os.environ.get("openai_api_key")
+
+model = "gpt-4o-mini"
+
+def load_model(openai_api_key):
+ return ChatOpenAI(
+ model_name=model,
+ openai_api_key=openai_api_key,
+ temperature=0.5
+ )
+
+# Initialize the OpenAI client with the API key
+client = OpenAI(api_key=openai_api_key)
+
+
+def convert_text_to_speech(text, output, voice):
+ try:
+ # Convert the final text to speech
+ response = client.audio.speech.create(model="tts-1-hd", voice=voice, input=text)
+
+ if isinstance(output, BytesIO):
+ # If output is a BytesIO object, write directly to it
+ for chunk in response.iter_bytes():
+ output.write(chunk)
+ else:
+ # If output is a file path, open and write to it
+ with open(output, 'wb') as f:
+ for chunk in response.iter_bytes():
+ f.write(chunk)
+
+ except Exception as e:
+ print(f"An error occurred: {e}")
+ # Fallback in case of error
+ response = client.audio.speech.create(model="tts-1-hd", voice=voice, input='Here is my Report.')
+
+ if isinstance(output, BytesIO):
+ for chunk in response.iter_bytes():
+ output.write(chunk)
+ else:
+ with open(output, 'wb') as f:
+ for chunk in response.iter_bytes():
+ f.write(chunk)
+
+
+def transcribe_audio(audio):
+ audio_file = open(audio, "rb")
+ transcription = client.audio.transcriptions.create(
+ model="whisper-1",
+ file=audio_file
+ )
+ return transcription.text
\ No newline at end of file
diff --git a/ai-medical-chatbot-master/8-Interviewer/hf/app.py b/ai-medical-chatbot-master/8-Interviewer/hf/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..041b384c8c73da8e56e4c28a8d5f4e6acf2e507f
--- /dev/null
+++ b/ai-medical-chatbot-master/8-Interviewer/hf/app.py
@@ -0,0 +1,232 @@
+import gradio as gr
+import tempfile
+import os
+from pathlib import Path
+from io import BytesIO
+from settings import (
+ respond,
+ generate_random_string,
+ reset_interview,
+ generate_interview_report,
+ generate_report_from_file,
+ interview_history,
+ question_count,
+ language,
+)
+from ai_config import convert_text_to_speech, transcribe_audio, n_of_questions
+from prompt_instructions import get_interview_initial_message_sarah, get_interview_initial_message_aaron
+
+# Global variables
+temp_audio_files = []
+initial_audio_path = None
+selected_interviewer = "Sarah"
+audio_enabled = True
+
+def reset_interview_action(voice):
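+    # Reset the interview state and generate the selected interviewer's opening message as text plus TTS audio.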
+ global question_count, interview_history, selected_interviewer
+ selected_interviewer = voice
+ question_count = 0
+ interview_history.clear()
+
+ if voice == "Sarah":
+ initial_message = get_interview_initial_message_sarah()
+ voice_setting = "alloy"
+ else:
+ initial_message = get_interview_initial_message_aaron()
+ voice_setting = "onyx"
+
+ initial_message = str(initial_message)
+
+ initial_audio_buffer = BytesIO()
+ convert_text_to_speech(initial_message, initial_audio_buffer, voice_setting)
+ initial_audio_buffer.seek(0)
+
+ with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_file:
+ temp_audio_path = temp_file.name
+ temp_file.write(initial_audio_buffer.getvalue())
+
+ temp_audio_files.append(temp_audio_path)
+
+ return (
+ [(None, initial_message[0] if isinstance(initial_message, tuple) else initial_message)],
+ gr.Audio(value=temp_audio_path, label=voice, autoplay=True, visible=False),
+ gr.Textbox(value="")
+ )
+
+def create_app():
+ global initial_audio_path, selected_interviewer, audio_enabled
+ # Initialize without any message history
+ initial_message = ""
+
+ with gr.Blocks(title="AI Medical Interviewer") as demo:
+ gr.Image(value="appendix/icon.jpeg", label='icon', width=20, scale=1, show_label=False, show_fullscreen_button=False,
+ show_download_button=False, show_share_button=False)
+ gr.Markdown(
+ """
+ # Medical Interviewer
+ This chatbot conducts medical interviews based on medical knowledge.
+ The interviewer will prepare a medical report based on the interview.
+ """
+ )
+
+ with gr.Tab("Interview"):
+ with gr.Row():
+ reset_button = gr.Button("Start Interview", size='sm', scale=1)
+ end_button = gr.Button("End Interview", size='sm', scale=1) # Added End Interview button
+ audio_output = gr.Audio(
+ label="Sarah",
+ scale=3,
+ autoplay=True,
+ visible=False, # Hides the audio but keeps it active
+ show_download_button=False,
+ )
+
+ # Chatbot initialized with no messages
+ chatbot = gr.Chatbot(value=[], label=f"Medical Interview📋")
+ with gr.Row():
+ msg = gr.Textbox(label="Type your message here...", scale=3)
+ audio_input = gr.Audio(sources=(["microphone"]), label="Record your message", type="filepath", scale=1)
+ send_button = gr.Button("Send")
+ pdf_output = gr.File(label="Download Report", visible=False)
+
+ def user(user_message, audio, history):
+ if audio is not None:
+ user_message = transcribe_audio(audio)
+ return "", None, history + [[user_message, None]]
+
+ def bot_response(chatbot, message):
+ global question_count, temp_audio_files, selected_interviewer, audio_enabled
+ question_count += 1
+
+ last_user_message = chatbot[-1][0] if chatbot else message
+
+ voice = "alloy" if selected_interviewer == "Sarah" else "onyx"
+ response, audio_buffer = respond(chatbot, last_user_message, voice, selected_interviewer)
+
+ for bot_message in response:
+ chatbot.append((None, bot_message[1]))
+
+ if isinstance(audio_buffer, BytesIO):
+ with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_file:
+ temp_audio_path = temp_file.name
+ temp_file.write(audio_buffer.getvalue())
+ temp_audio_files.append(temp_audio_path)
+ audio_output = gr.Audio(value=temp_audio_path, label=selected_interviewer, autoplay=audio_enabled, visible=False)
+ else:
+ audio_output = gr.Audio(value=audio_buffer, label=selected_interviewer, autoplay=audio_enabled, visible=False)
+
+ if question_count >= n_of_questions():
+ conclusion_message = "Thank you for participating in this interview. We have reached the end of our session. I hope this conversation has been helpful. Take care!"
+ chatbot.append((None, conclusion_message))
+
+ conclusion_audio_buffer = BytesIO()
+ convert_text_to_speech(conclusion_message, conclusion_audio_buffer, voice)
+ conclusion_audio_buffer.seek(0)
+
+ with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_file:
+ temp_audio_path = temp_file.name
+ temp_file.write(conclusion_audio_buffer.getvalue())
+ temp_audio_files.append(temp_audio_path)
+ audio_output = gr.Audio(value=temp_audio_path, label=selected_interviewer, autoplay=audio_enabled, visible=False)
+
+ report_content, pdf_path = generate_interview_report(interview_history, language)
+ chatbot.append((None, f"Interview Report:\n\n{report_content}"))
+
+ return chatbot, audio_output, gr.File(visible=True, value=pdf_path)
+
+ return chatbot, audio_output, gr.File(visible=False)
+
+ # Function to reset and start the interview, which populates the chatbot with the initial message
+ def start_interview():
+ global selected_interviewer
+ return reset_interview_action(selected_interviewer)
+
+ # Function to end the interview
+ def end_interview(chatbot):
+ chatbot.append((None, "The interview has been ended by the user."))
+ return chatbot, gr.Audio(visible=False), gr.Textbox(value="")
+
+ # Bind actions to buttons
+ reset_button.click(
+ start_interview,
+ inputs=[],
+ outputs=[chatbot, audio_output, msg]
+ )
+
+ end_button.click(
+ end_interview,
+ inputs=[chatbot],
+ outputs=[chatbot, audio_output, msg]
+ )
+
+ msg.submit(user, [msg, audio_input, chatbot], [msg, audio_input, chatbot], queue=False).then(
+ bot_response, [chatbot, msg], [chatbot, audio_output, pdf_output]
+ )
+
+ send_button.click(user, [msg, audio_input, chatbot], [msg, audio_input, chatbot], queue=False).then(
+ bot_response, [chatbot, msg], [chatbot, audio_output, pdf_output]
+ )
+
+ with gr.Tab("Settings"):
+ gr.Markdown('Configure your settings below:')
+ audio_toggle = gr.Checkbox(label="Enable Audio", value=True)
+ interviewer_radio = gr.Radio(["Sarah", "Aaron"], label="Select Interviewer", value="Sarah")
+
+ def update_settings(audio_status, interviewer_choice):
+ global audio_enabled, selected_interviewer
+ audio_enabled = audio_status
+ selected_interviewer = interviewer_choice
+ return f"Settings updated: Audio {'Enabled' if audio_enabled else 'Disabled'}, Interviewer: {selected_interviewer}"
+
+ settings_button = gr.Button("Apply Settings")
+ settings_message = gr.Textbox(visible=True)
+
+ settings_button.click(
+ update_settings,
+ inputs=[audio_toggle, interviewer_radio],
+ outputs=[settings_message]
+ )
+
+ with gr.Tab("Upload Document"):
+ gr.Markdown('Please upload a document that contains content written about a patient or by the patient.')
+ file_input = gr.File(label="Upload a TXT, PDF, or DOCX file")
+ language_input = 'English'
+ generate_button = gr.Button("Generate Report")
+ report_output = gr.Textbox(label="Generated Report", lines=100, visible=False)
+ pdf_output = gr.File(label="Download Report", visible=True)
+
+        def generate_report_and_pdf(file, language=language_input):  # only file_input is wired as a Gradio input; language falls back to the fixed value above
+ report_content, pdf_path = generate_report_from_file(file, language)
+ return report_content, pdf_path, gr.File(visible=True)
+
+ generate_button.click(
+ generate_report_and_pdf,
+ inputs=[file_input],
+ outputs=[report_output, pdf_output, pdf_output]
+ )
+
+ with gr.Tab("Description"):
+ with open('appendix/description.txt', 'r', encoding='utf-8') as file:
+ description_txt = file.read()
+ gr.Markdown(description_txt)
+ gr.Image(value="appendix/diagram.png", label='diagram', width=700, scale=1, show_label=False)
+
+ return demo
+
+# Clean up function
+def cleanup():
+ global temp_audio_files, initial_audio_path
+ for audio_file in temp_audio_files:
+ if os.path.exists(audio_file):
+ os.unlink(audio_file)
+ temp_audio_files.clear()
+
+ if initial_audio_path and os.path.exists(initial_audio_path):
+ os.unlink(initial_audio_path)
+
+if __name__ == "__main__":
+ app = create_app()
+ try:
+ app.launch()
+ finally:
+ cleanup()
diff --git a/ai-medical-chatbot-master/8-Interviewer/hf/appendix/Psi.png b/ai-medical-chatbot-master/8-Interviewer/hf/appendix/Psi.png
new file mode 100644
index 0000000000000000000000000000000000000000..109859731ec1848a9e798b3d9aaa0e5aaf4ffca2
Binary files /dev/null and b/ai-medical-chatbot-master/8-Interviewer/hf/appendix/Psi.png differ
diff --git a/ai-medical-chatbot-master/8-Interviewer/hf/appendix/description.txt b/ai-medical-chatbot-master/8-Interviewer/hf/appendix/description.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7e7ad1c85041e2d3cb4ecb9feaa4e36f8a381f59
--- /dev/null
+++ b/ai-medical-chatbot-master/8-Interviewer/hf/appendix/description.txt
@@ -0,0 +1,75 @@
+# Medical Interviewer
+This chatbot conducts medical interviews based on psychological knowledge.
+
+The interviewer will prepare a medical report based on the interview.
+
+* Please note that this is a simulation and should not be used as a substitute for professional medical advice.
+* It is important to emphasize that any information shared is confidential and cannot be accessed.
+* In any case, it is recommended not to share sensitive information.
+
+
+**Medical Interviewer** is an AI platform designed to simulate medical interviews. It leverages NLP and speech technologies to emulate a medical psychologist, offering insightful assessments and generating detailed medical reports.
+
+This platform is ideal for educational, research, and preliminary assessment purposes but should not replace professional medical advice.
+
+## Features
+
+**Key Features**:
+- **Simulated Interviews**: Conducts interviews with focused medically relevant questions.
+- **Natural Language Processing**: Understands and generates contextually relevant questions.
+- **LangChain**: Creates the NLP chains that drive the interview and report generation.
+- **Audio Interaction**: Voice-based conversation that lets the user speak with the bot, simulating a real interview.
+- **Report Generation**: Automatically creates comprehensive medical reports after each session.
+- **Document Upload for Reports**: Generates reports from uploaded TXT, PDF, or DOCX files.
+- **Multi-language Support**: Conducts interviews and generates reports in the user's preferred language.
+- **Selectable Interviewers**: Users can select their preferred interviewer, each with a different professional background, experience, and temperament. Options include:
+  - Sarah: An empathic, compassionate interviewer with over 30 years of experience, specializing in trauma, anxiety disorders, and family therapy.
+  - Aaron: A tough-minded interviewer with over 15 years of experience, specializing in stress, trauma, and high-performance demands, with a background as a military officer.
+
+
+## Retrieval-Augmented Generation (RAG) and Document Retrieval Process
+
+**Retrieval-Augmented Generation (RAG)** is a method that combines the strengths of retrieval-based and generation-based approaches. RAG helps to ensure that the interview questions generated by the AI are both contextually relevant and grounded in authoritative sources, while also keeping response times low.
+
+1. **Document Embeddings**: The documents are converted into embeddings using OpenAI’s embedding models. These embeddings capture the semantic meaning of the text and are used to facilitate efficient retrieval.
+2. **FAISS Indexing**: The embeddings are stored in a FAISS (Facebook AI Similarity Search) index. FAISS is optimized for similarity search and clustering of dense vectors, making it ideal for this purpose. (A minimal sketch of this indexing step appears after this list.)
+3. **Query Embedding**: When a user input or interview context is provided, it is also converted into an embedding.
+4. **Similarity Search**: The query embedding is used to search the FAISS index to retrieve the most relevant documents based on their embeddings.
+5. **Top-K Retrieval**: The system retrieves the top-K documents that are most similar to the user’s query embedding. These documents are then used to generate the next interview question, ensuring that the responses are based on relevant and accurate information.
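+
+A minimal sketch of the indexing step described in points 1 and 2 (the source file name and chunk sizes are illustrative; the deployed app only loads a prebuilt index):
+
+```python
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import TextLoader
+from langchain_community.vectorstores import FAISS
+from langchain_openai import OpenAIEmbeddings
+
+# Split the source documents into chunks and embed them.
+docs = TextLoader("dsm5_excerpts.txt").load()  # illustrative file name
+splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
+chunks = splitter.split_documents(docs)
+
+# Build and persist the FAISS index that is queried at interview time.
+index = FAISS.from_documents(chunks, OpenAIEmbeddings())
+index.save_local("knowledge/faiss_index_all_documents")
+```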
+
+## Documents and Knowledge Database
+
+The platform uses a rich set of documents and knowledge bases to inform the AI’s questioning and reporting processes. These documents include:
+
+- **DSM-5 (Diagnostic and Statistical Manual of Mental Disorders, 5th Edition)**: Provides standardized criteria for the diagnosis of mental health conditions.
+- **PDM-2 (Psychodynamic Diagnostic Manual, 2nd Edition)**: Offers a psychodynamic perspective on mental health diagnosis.
+- **Personalities Descriptions**: Detailed descriptions of various personality types and traits.
+- **Defence Mechanisms**: Information on psychological strategies used by individuals to cope with reality and maintain self-image.
+- **Big Five Traits**: Descriptions of the five-factor model of personality traits.
+- **Attachment Styles**: Framework for understanding different types of attachment in interpersonal relationships.
+- **Interview Conduction Guides for medical Psychologists**: Guidelines and best practices for conducting medical interviews.
+
+These documents are processed and indexed, enabling the AI to retrieve relevant excerpts during the interview to generate questions that are grounded in established psychological knowledge.
+
+## Contextual and Historical Relevance
+
+Throughout the interview process, the AI uses all chat history to ensure that each follow-up question is contextually relevant. By leveraging both the immediate user input and the full history of the conversation, the AI can provide a coherent and comprehensive interview experience. The use of RAG ensures that the follow-up questions are informed not only by the user's previous responses but also by the most relevant and authoritative information available in the knowledge base.
+
+## Human-like simulated environment
+It supports audio interactions by converting text questions into speech and transcribing user audio responses into text, facilitated by OpenAI’s text-to-speech (TTS) and Whisper speech-to-text technologies. This creates a simulated environment for real-like conversational interviews, making the interactions more human-like.
+
+### Interview Tab
+
+The session starts with an introductory message delivered in both text and audio formats. Users respond by typing or recording audio responses, which the AI processes to generate and return relevant follow-up questions based on context and the retrieved documents. The conversation continues until a predetermined number of questions have been asked. At the end of the session, a detailed medical report is generated and available for download as a PDF.
+
+### Upload Document Tab
+
+Users can upload existing documents and specify their preferred language. The system analyzes the document content and generates a detailed medical report, which can be displayed and downloaded.
+
+## Disclaimer
+
+This platform is a simulation and should not replace professional medical advice. Always seek advice from a qualified healthcare provider for medical concerns.
+
+---
+
+**Medical Interviewer** stands as a testament to the potential of advanced AI technologies in simulating medical psychology interviews and generating detailed reports. For technical details, refer to the in-code documentation. This platform offers a valuable tool for educational and research purposes, providing an enriching and interactive user experience.
diff --git a/ai-medical-chatbot-master/8-Interviewer/hf/appendix/diagram.png b/ai-medical-chatbot-master/8-Interviewer/hf/appendix/diagram.png
new file mode 100644
index 0000000000000000000000000000000000000000..29fb83523d641aa2c4c6bc69bdd49480ad4a07e6
Binary files /dev/null and b/ai-medical-chatbot-master/8-Interviewer/hf/appendix/diagram.png differ
diff --git a/ai-medical-chatbot-master/8-Interviewer/hf/appendix/icon.jpeg b/ai-medical-chatbot-master/8-Interviewer/hf/appendix/icon.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..84aae860ac7327ed4f64291d13378aa389695246
Binary files /dev/null and b/ai-medical-chatbot-master/8-Interviewer/hf/appendix/icon.jpeg differ
diff --git a/ai-medical-chatbot-master/8-Interviewer/hf/knowledge_retrieval.py b/ai-medical-chatbot-master/8-Interviewer/hf/knowledge_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ca0a1bc6d579d77a59f9b01d6d0d13821591105
--- /dev/null
+++ b/ai-medical-chatbot-master/8-Interviewer/hf/knowledge_retrieval.py
@@ -0,0 +1,91 @@
+import random
+from langchain_community.vectorstores import FAISS
+from langchain_openai import OpenAIEmbeddings
+from langchain.chains import create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_core.prompts import ChatPromptTemplate
+from langchain.retrievers import EnsembleRetriever
+from ai_config import n_of_questions, openai_api_key
+from prompt_instructions import get_interview_prompt_sarah, get_interview_prompt_aaron, get_report_prompt
+
+n_of_questions = n_of_questions()
+
+def setup_knowledge_retrieval(llm, language='english', voice='Sarah'):
+ embedding_model = OpenAIEmbeddings(openai_api_key=openai_api_key)
+
+ documents_faiss_index = FAISS.load_local("knowledge/faiss_index_all_documents", embedding_model,
+ allow_dangerous_deserialization=True)
+
+ documents_retriever = documents_faiss_index.as_retriever()
+
+ combined_retriever = EnsembleRetriever(
+ retrievers=[documents_retriever]
+ )
+
+ if voice == 'Sarah':
+ interview_prompt = ChatPromptTemplate.from_messages([
+ ("system", get_interview_prompt_sarah(language, n_of_questions)),
+ ("human", "{input}")
+ ])
+ else:
+ interview_prompt = ChatPromptTemplate.from_messages([
+ ("system", get_interview_prompt_aaron(language, n_of_questions)),
+ ("human", "{input}")
+ ])
+
+ report_prompt = ChatPromptTemplate.from_messages([
+ ("system", get_report_prompt(language)),
+ ("human", "Please provide a concise clinical report based on the interview.")
+ ])
+
+ interview_chain = create_stuff_documents_chain(llm, interview_prompt)
+ report_chain = create_stuff_documents_chain(llm, report_prompt)
+
+ interview_retrieval_chain = create_retrieval_chain(combined_retriever, interview_chain)
+ report_retrieval_chain = create_retrieval_chain(combined_retriever, report_chain)
+
+ return interview_retrieval_chain, report_retrieval_chain, combined_retriever
+
+
+def get_next_response(interview_chain, message, history, question_count):
+ combined_history = "\n".join(history)
+
+ # Check if the interview should end
+ if question_count >= n_of_questions:
+ return "Thank you for your responses. I will now prepare a report."
+
+ # Generate the next question
+ result = interview_chain.invoke({
+ "input": f"Based on the patient's last response: '{message}', and considering the full interview history, ask a specific, detailed question that hasn't been asked before and is relevant to the patient's situation.",
+ "history": combined_history,
+ "question_number": question_count + 1 # Increment question number here
+ })
+
+ next_question = result.get("answer", "Could you provide more details on that?")
+
+ # Update history with the new question and response
+ history.append(f"Q{question_count + 1}: {next_question}")
+ history.append(f"A{question_count + 1}: {message}")
+
+ return next_question
+
+
+def generate_report(report_chain, history, language):
+ combined_history = "\n".join(history)
+
+ result = report_chain.invoke({
+ "input": "Please provide a clinical report based on the interview.",
+ "history": combined_history,
+ "language": language
+ })
+
+ return result.get("answer", "Unable to generate report due to insufficient information.")
+
+
+def get_initial_question(interview_chain):
+ result = interview_chain.invoke({
+ "input": "What should be the first question in a clinical psychology interview?",
+ "history": "",
+ "question_number": 1
+ })
+ return result.get("answer", "Could you tell me a little bit about yourself and what brings you here today?")
diff --git a/ai-medical-chatbot-master/8-Interviewer/hf/prompt_instructions.py b/ai-medical-chatbot-master/8-Interviewer/hf/prompt_instructions.py
new file mode 100644
index 0000000000000000000000000000000000000000..394e71466d372d635da6388f2f3188506a03352c
--- /dev/null
+++ b/ai-medical-chatbot-master/8-Interviewer/hf/prompt_instructions.py
@@ -0,0 +1,162 @@
+from datetime import datetime
+from ai_config import n_of_questions
+current_datetime = datetime.now()
+current_date = current_datetime.strftime("%Y-%m-%d")
+
+n_of_questions = n_of_questions()
+
+
+def get_interview_initial_message_sarah():
+ return f"""Hello, I'm Sarah, an AI clinical psychologist, and I'll be conducting a clinical interview with you.
+ I will ask you about {n_of_questions} questions.
+ Feel free to share as much or as little as you're comfortable with.
+ Could you please tell me which language you prefer to speak or conduct this interview in? """
+
+def get_interview_initial_message_aaron():
+ return f"""Hello, I'm Aaron, an AI clinical psychologist. I'll be conducting a brief interview with you.
+    Which language do you prefer for this interview? My mother tongue is English, so please bear with me if there are any mistakes."""
+
+
+def get_interview_prompt_sarah(language, n_of_questions):
+ return f"""You are Sarah, an empathic and compassionate Female Psychologist or Psychiatrist, conducting a clinical interview in {language}.
+
+A highly experienced and dedicated Clinical Psychologist with over 30 years of experience in clinical practice and research.
+Specializing in trauma, anxiety disorders, and family therapy, Sarah has a proven track record of successfully treating a wide range of psychological conditions.
+Her deep commitment to patient care and mental health advocacy has driven her to develop innovative therapeutic approaches and lead community mental health initiatives.
+Sarah's extensive career is marked by her unwavering dedication to giving back to the community.
+She has been actively involved in various community service efforts, including several years of work with children with disabilities and autistic children.
+Her compassionate approach and ability to connect with patients of all ages have made her a respected figure in the field of psychology.
+Sarah is not only a skilled clinician but also a passionate advocate for mental health, continuously striving to improve the lives of those she serves.
+
+Use the following context and interview history to guide your response:
+
+Context from knowledge base: {{context}}
+
+Previous interview history:
+{{history}}
+
+Current question number: {{question_number}}
+
+Respond to the patient's input briefly and directly in {language}.
+Ask a specific, detailed question that hasn't been asked before.
+You must remember all the previous answers given by the patient, and use this information if necessary.
+If you notice anything particularly special, unusual, or strange in the answers that calls for deeper or more in-depth understanding, ask about it or direct your question to clarify the matter - this information may be beneficial and may hint at the patient's personality or traits.
+The first question is to ask for the patient's name.
+The second question is to ask for their age.
+The third question is to ask where they live.
+The fourth question is to ask what they do for work.
+The fifth question is to ask about the nature of their relationship with their parents.
+Keep in mind that you have {n_of_questions} total number of questions.
+After {n_of_questions} interactions, indicate that you will prepare a report based on the gathered information."""
+
+
+def get_interview_prompt_aaron(language, n_of_questions):
+    return f"""You are Aaron, a tough, impatient, and not especially empathic Male Psychologist, Coach, and Mentor, conducting a clinical interview in {language}.
+
+ Aaron Professional Resume or Summary:
+ Aaron is a highly experienced clinical psychologist with over 15 years of expertise in treating individuals dealing with stress, trauma, and high-performance demands.
+ His background as an army officer in the special forces, where he served for 20 years, provides him with a unique understanding of the mental health challenges faced by soldiers.
+ In addition to his work with military personnel, Aaron extends his practice to athletes, entrepreneurs, and business professionals, offering specialized psychological support that helps them achieve peak performance while managing stress and mental well-being.
+ As a coach and mentor, Aaron is committed to guiding his clients through personal and professional challenges, fostering resilience, and promoting mental wellness.
+
+ Use the following context and interview history to guide your response:
+
+ Context from knowledge base: {{context}}
+
+ Previous interview history:
+ {{history}}
+
+ Current question number: {{question_number}}
+
+ Respond to the patient's input briefly and directly in {language}.
+ Ask a specific, detailed question that hasn't been asked before.
+ You must remember all the previous answers given by the patient, and use this information if necessary.
+ If you notice anything particularly special, unusual, or strange in the answers that calls for deeper or more in-depth understanding, ask about it or direct your question to clarify the matter - this information may be beneficial and may hint at the patient's personality or traits.
+ The first question is to ask for the patient's name.
+ The second question is to ask for their age.
+ The third question is to ask where they live.
+ The fourth question is to ask what they do for work.
+ The fifth question is to ask about the nature of their relationship with their parents.
+ Keep in mind that you have {n_of_questions} total number of questions.
+ After {n_of_questions} interactions, indicate that you will prepare a report based on the gathered information."""
+
+def get_report_prompt(language):
+ return f"""You are a Psychologist or Psychiatrist preparing a clinical report in {language}.
+Use the following context and interview history to create your report.
+Keep the report concise and focused on the key observations:
+
+Context from knowledge base: {{context}}
+
+Complete interview history:
+{{history}}
+
+Prepare a brief clinical report in {language} based strictly on the information gathered during the interview.
+Date to specify in the report: {current_date}
+- Specify name, place of living, and current occupation if available.
+- Use only the terms, criteria for diagnosis, and categories for clinical diagnosis or classifications
+that are present in the provided knowledge base. Do not introduce any external information or terminology.
+* In your diagnosis, you must be very careful. That is, you need to have enough evidence and information to rate or diagnose a patient.
+* Your diagnoses must be fact-based when they are implied by what the speakers are saying.
+* Write technical, clinical or professional terms only in the English language.
+* As a rule, in cases where there is little information about the patient through the conversation or through
+the things they say, the diagnosis will be more difficult, and the ratings will be lower,
+because it is difficult to draw conclusions when our information about the patient is scarce.
+Be very selective and careful with the facts you write or provide in the report;
+in such a case, this must also be mentioned and taken into consideration.
+* Do not provide any clinical diagnosis or conclusions in the report if the patient has not provided enough information.
+* Any diagnosis or interpretation requires the presentation of facts, foundations, and explanations.
+* You can also give examples or quotes.
+* There are two parts for the report - main report and additional report.
+* Structure the main report to include observed symptoms, potential diagnoses (if applicable), and any other
+relevant clinical observations, all within the framework of the given knowledge.
+
+First, write the main report; then, in addition to the main report, add the following sections as the additional report:
+- An overall clinical impression
+- Dominant personality characteristics
+- Style of communication
+- What mainly preoccupies them - themes or topics that preoccupy them in particular
+- Possible personal weaknesses or triggers
+- Defense Mechanisms
+- How they are likely to react to stressful or emotionally charged situations or events
+- How they might deal with unexpected situations or events
+- How they might behave in a group vs alone
+- How they might behave in intimate relationships, and which partners they are usually drawn or attracted to. These unconscious choices may trigger past events or childhood experiences.
+- How they will function in work environments, and whether they will be able to contribute and perform properly and in a stable manner over time.
+- Degree of psychological mental health assessment
+- What the experience of meeting such a person will generally be like
+- Other things or further assessments that can be examined from a psychological perspective, and in which situations it is necessary to examine the person's reactions in order to get more indications of a diagnosis of their personality
+- The type of treatment that is recommended.
+
+Furthermore, include the following:
+
+Big Five Traits (ratings of 0-10):
+Extraversion: [rating]
+Agreeableness: [rating]
+Conscientiousness: [rating]
+Neuroticism: [rating]
+Openness: [rating]
+Big Five Traits explanation: [explanation]
+
+Personality Disorders or Styles (ratings of 0-4):
+Depressed: [rating]
+Paranoid: [rating]
+Schizoid-Schizotypal: [rating]
+Antisocial-Psychopathic: [rating]
+Borderline-Dysregulated: [rating]
+Narcissistic: [rating]
+Anxious-Avoidant: [rating]
+Dependent-Victimized: [rating]
+Hysteric-Histrionic: [rating]
+Obsessional: [rating]
+Personality Disorders or Styles explanation: [explanation]
+
+Attachment Styles (ratings of 0-10):
+Secured: [rating]
+Anxious-Preoccupied: [rating]
+Dismissive-Avoidant: [rating]
+Fearful-Avoidant: [rating]
+Avoidance: [rating]
+Positive view toward the Self: [rating]
+Positive view toward Others: [rating]
+Attachment Styles explanation: [explanation]
+"""
\ No newline at end of file
diff --git a/ai-medical-chatbot-master/8-Interviewer/hf/requirements.txt b/ai-medical-chatbot-master/8-Interviewer/hf/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..be2f6a8521350d799966f8f2ecb87a6141e4e4fe
--- /dev/null
+++ b/ai-medical-chatbot-master/8-Interviewer/hf/requirements.txt
@@ -0,0 +1,20 @@
+python-dotenv==1.0.1
+pandas==2.1.4
+langchain==0.2.6
+langchain-openai==0.1.14
+langchain-core==0.2.11
+langchain-ibm==0.1.8
+langchain-community==0.2.6
+ibm-watson-machine-learning==1.0.359
+ipykernel
+notebook
+urllib3
+requests==2.32.0
+PyPDF2
+python-docx
+reportlab
+openai
+faiss-cpu
+cryptography
+pymysql
+scikit-learn
\ No newline at end of file
diff --git a/ai-medical-chatbot-master/8-Interviewer/hf/settings.py b/ai-medical-chatbot-master/8-Interviewer/hf/settings.py
new file mode 100644
index 0000000000000000000000000000000000000000..9efa4f37a37d8d28543ebae414126b36d928c4bb
--- /dev/null
+++ b/ai-medical-chatbot-master/8-Interviewer/hf/settings.py
@@ -0,0 +1,245 @@
+import traceback
+from datetime import datetime
+from pathlib import Path
+import os
+import random
+import string
+import tempfile
+import re
+import io
+import PyPDF2
+import docx
+from reportlab.pdfgen import canvas
+from reportlab.lib.pagesizes import letter
+from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer
+from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
+from reportlab.lib.enums import TA_JUSTIFY
+from ai_config import n_of_questions, load_model, openai_api_key, convert_text_to_speech
+from knowledge_retrieval import setup_knowledge_retrieval, generate_report
+
+# Initialize settings
+n_of_questions = n_of_questions()
+current_datetime = datetime.now()
+human_readable_datetime = current_datetime.strftime("%B %d, %Y at %H:%M")
+current_date = current_datetime.strftime("%Y-%m-%d")
+
+# Initialize the model and retrieval chain
+try:
+ llm = load_model(openai_api_key)
+ interview_retrieval_chain, report_retrieval_chain, combined_retriever = setup_knowledge_retrieval(llm)
+ knowledge_base_connected = True
+ print("Successfully connected to the knowledge base.")
+except Exception as e:
+ print(f"Error initializing the model or retrieval chain: {str(e)}")
+ knowledge_base_connected = False
+ print("Falling back to basic mode without knowledge base.")
+
+question_count = 0
+interview_history = []
+last_audio_path = None # Variable to store the path of the last audio file
+initial_audio_path = None # Variable to store the path of the initial audio file
+language = None
+
+def generate_random_string(length=5):
+ return ''.join(random.choices(string.ascii_letters + string.digits, k=length))
+def respond(message, history, voice, selected_interviewer):
+ global question_count, interview_history, combined_retriever, last_audio_path, initial_audio_path, language, interview_retrieval_chain, report_retrieval_chain
+
+ if not isinstance(history, list):
+ history = []
+ if not history or not history[-1]:
+ history.append(["", ""])
+
+ # Extract the actual message text
+ if isinstance(message, list):
+ message = message[-1][0] if message and isinstance(message[-1], list) else message[-1]
+
+ question_count += 1
+ interview_history.append(f"Q{question_count}: {message}")
+ history_str = "\n".join(interview_history)
+ print("Starting interview", question_count)
+
+ try:
+ if knowledge_base_connected:
+ if question_count == 1:
+ # Capture the language from the first response
+ language = message.strip().lower()
+ # Reinitialize the interview chain with the new language
+ interview_retrieval_chain, report_retrieval_chain, combined_retriever = setup_knowledge_retrieval(
+ llm, language, selected_interviewer)
+
+ if question_count < n_of_questions:
+ result = interview_retrieval_chain.invoke({
+ "input": f"Based on the patient's statement: '{message}', what should be the next question?",
+ "history": history_str,
+ "question_number": question_count + 1,
+ "language": language
+ })
+ question = result.get("answer", f"Can you tell me more about that? (in {language})")
+ else:
+ result = generate_report(report_retrieval_chain, interview_history, language)
+ question = result
+ speech_file_path = None # Skip audio generation for the report
+
+ if question:
+ random_suffix = generate_random_string()
+ speech_file_path = Path(__file__).parent / f"question_{question_count}_{random_suffix}.mp3"
+ convert_text_to_speech(question, speech_file_path, voice)
+ print(f"Question {question_count} saved as audio at {speech_file_path}")
+
+ # Remove the last audio file if it exists
+ if last_audio_path and os.path.exists(last_audio_path):
+ os.remove(last_audio_path)
+ last_audio_path = speech_file_path
+ else:
+ speech_file_path = None # Skip audio generation for the report
+
+ else:
+ # Fallback mode without knowledge base
+ question = f"Can you elaborate on that? (in {language})"
+ if question_count < n_of_questions:
+ speech_file_path = Path(__file__).parent / f"question_{question_count}.mp3"
+ convert_text_to_speech(question, speech_file_path, voice)
+ print(f"Question {question_count} saved as audio at {speech_file_path}")
+
+ if last_audio_path and os.path.exists(last_audio_path):
+ os.remove(last_audio_path)
+ last_audio_path = speech_file_path
+ else:
+ speech_file_path = None
+
+ history[-1][1] = f"{question}"
+
+ # Remove the initial question audio file after the first user response
+ if initial_audio_path and os.path.exists(initial_audio_path):
+ os.remove(initial_audio_path)
+ initial_audio_path = None
+
+ # Clean up older files based on question_count
+ if question_count > 1:
+ previous_audio_path = Path(__file__).parent / f"question_{question_count-1}_{random_suffix}.mp3"
+ if os.path.exists(previous_audio_path):
+ os.remove(previous_audio_path)
+
+ return history, str(speech_file_path) if speech_file_path else None
+
+ except Exception as e:
+ print(f"Error in retrieval chain: {str(e)}")
+ print(traceback.format_exc())
+ return history, None
+
+
+
+
+def reset_interview():
+ """Reset the interview state."""
+ global question_count, interview_history, last_audio_path, initial_audio_path
+ question_count = 0
+ interview_history = []
+ if last_audio_path and os.path.exists(last_audio_path):
+ os.remove(last_audio_path)
+ last_audio_path = None
+ initial_audio_path = None
+
+
+def read_file(file):
+ if file is None:
+ return "No file uploaded"
+
+ if isinstance(file, str):
+ with open(file, 'r', encoding='utf-8') as f:
+ return f.read()
+
+ if hasattr(file, 'name'): # Check if it's a file-like object
+ if file.name.endswith('.txt'):
+ return file.content
+ elif file.name.endswith('.pdf'):
+ pdf_reader = PyPDF2.PdfReader(io.BytesIO(file.content))
+ return "\n".join(page.extract_text() for page in pdf_reader.pages)
+ elif file.name.endswith('.docx'):
+ doc = docx.Document(io.BytesIO(file.content))
+ return "\n".join(paragraph.text for paragraph in doc.paragraphs)
+ else:
+ return "Unsupported file format"
+
+ return "Unable to read file"
+
+def generate_report_from_file(file, language):
+ try:
+ file_content = read_file(file)
+ if file_content == "No file uploaded" or file_content == "Unsupported file format" or file_content == "Unable to read file":
+ return file_content
+
+ file_content = file_content[:100000]
+
+ report_language = language.strip().lower() if language else "english"
+ print('preferred language:', report_language)
+ print(f"Generating report in language: {report_language}") # For debugging
+
+ # Reinitialize the report chain with the new language
+ _, report_retrieval_chain, _ = setup_knowledge_retrieval(llm, report_language)
+
+ result = report_retrieval_chain.invoke({
+ "input": "Please provide a clinical report based on the following content:",
+ "history": file_content,
+ "language": report_language
+ })
+ report_content = result.get("answer", "Unable to generate report due to insufficient information.")
+ pdf_path = create_pdf(report_content)
+ return report_content, pdf_path
+ except Exception as e:
+ return f"An error occurred while processing the file: {str(e)}", None
+
+
+def generate_interview_report(interview_history, language):
+ try:
+ report_language = language.strip().lower() if language else "english"
+ print('preferred report_language language:', report_language)
+ _, report_retrieval_chain, _ = setup_knowledge_retrieval(llm, report_language)
+
+ result = report_retrieval_chain.invoke({
+ "input": "Please provide a clinical report based on the following interview:",
+ "history": "\n".join(interview_history),
+ "language": report_language
+ })
+ report_content = result.get("answer", "Unable to generate report due to insufficient information.")
+ pdf_path = create_pdf(report_content)
+ return report_content, pdf_path
+ except Exception as e:
+ return f"An error occurred while generating the report: {str(e)}", None
+
+def create_pdf(content):
+
+ random_string = generate_random_string()
+
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=f'_report.pdf')
+ doc = SimpleDocTemplate(temp_file.name, pagesize=letter)
+ styles = getSampleStyleSheet()
+
+ # Create a custom style for bold text
+ bold_style = ParagraphStyle('Bold', parent=styles['Normal'], fontName='Helvetica-Bold', fontSize=10)
+
+ # Create a custom style for normal text with justification
+ normal_style = ParagraphStyle('Normal', parent=styles['Normal'], alignment=TA_JUSTIFY)
+
+ flowables = []
+
+ for line in content.split('\n'):
+ # Use regex to find words surrounded by **
+ parts = re.split(r'(\*\*.*?\*\*)', line)
+ paragraph_parts = []
+
+ for part in parts:
+ if part.startswith('**') and part.endswith('**'):
+ # Bold text
+ bold_text = part.strip('**')
+ paragraph_parts.append(Paragraph(bold_text, bold_style))
+ else:
+ # Normal text
+ paragraph_parts.append(Paragraph(part, normal_style))
+
+ flowables.extend(paragraph_parts)
+ flowables.append(Spacer(1, 12)) # Add space between paragraphs
+
+ doc.build(flowables)
+ return temp_file.name
\ No newline at end of file
diff --git a/ai-medical-chatbot-master/Chatbot-Medical-Llama3-v2.ipynb b/ai-medical-chatbot-master/Chatbot-Medical-Llama3-v2.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..8b074de88cb3c6e625e28682c05999b4d7eba448
--- /dev/null
+++ b/ai-medical-chatbot-master/Chatbot-Medical-Llama3-v2.ipynb
@@ -0,0 +1,195 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "toc_visible": true,
+ "machine_shape": "hm",
+ "gpuType": "A100",
+ "authorship_tag": "ABX9TyNMzCSw8XLVSOI/aj2QMEti",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Medical AI Chatbot\n",
+ "## [ruslanmv/Medical-Llama3-v2](https://github.com/ruslanmv/ai-medical-chatbot/blob/master/Chatbot-Medical-Llama3-v2.ipynb)"
+ ],
+ "metadata": {
+ "id": "D2JxjUcy8nZg"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from IPython.display import clear_output\n",
+ "!pip install bitsandbytes accelerate gradio\n",
+ "clear_output()"
+ ],
+ "metadata": {
+ "id": "eS2NsgQgvhZQ"
+ },
+ "execution_count": 2,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n",
+ "import torch\n",
+ "\n",
+ "# Define BitsAndBytesConfig\n",
+ "bnb_config = BitsAndBytesConfig(load_in_4bit=True,\n",
+ " bnb_4bit_quant_type=\"nf4\",\n",
+ " bnb_4bit_compute_dtype=torch.float16)\n",
+ "\n",
+ "# Model name\n",
+ "model_name = \"ruslanmv/Medical-Llama3-v2\"\n",
+ "\n",
+ "# Load tokenizer and model with BitsAndBytesConfig\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
+ "model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config, device_map=\"auto\")\n",
+ "\n",
+ "# The 4-bit quantized model is dispatched to the GPU by device_map, so keep a handle to its device\n",
+ "device = model.device"
+ ],
+ "metadata": {
+ "id": "teoE-Zmv4LlP"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import gradio as gr\n",
+ "\n",
+ "# Define the respond function\n",
+ "def respond(\n",
+ " message,\n",
+ " history: list[tuple[str, str]],\n",
+ " system_message,\n",
+ " max_tokens,\n",
+ " temperature,\n",
+ " top_p,\n",
+ "):\n",
+ " messages = [{\"role\": \"system\", \"content\": system_message}]\n",
+ "\n",
+ " for val in history:\n",
+ " if val[0]:\n",
+ " messages.append({\"role\": \"user\", \"content\": val[0]})\n",
+ " if val[1]:\n",
+ " messages.append({\"role\": \"assistant\", \"content\": val[1]})\n",
+ "\n",
+ " messages.append({\"role\": \"user\", \"content\": message})\n",
+ "\n",
+ " # Format the conversation as a single string for the model\n",
+ " prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
+ " inputs = tokenizer(prompt, return_tensors=\"pt\", truncation=True, padding=True, max_length=1000)\n",
+ "\n",
+ " # Move inputs to device\n",
+ " input_ids = inputs['input_ids'].to(device)\n",
+ " attention_mask = inputs['attention_mask'].to(device)\n",
+ "\n",
+ " # Generate the response\n",
+ " with torch.no_grad():\n",
+ " outputs = model.generate(\n",
+ " input_ids=input_ids,\n",
+ " attention_mask=attention_mask,\n",
+ " max_length=max_tokens,\n",
+ " temperature=temperature,\n",
+ " top_p=top_p,\n",
+ " use_cache=True\n",
+ " )\n",
+ "\n",
+ " # Extract the response\n",
+ " response_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]\n",
+ "\n",
+ " # Remove the prompt and system message from the response\n",
+ " response_text = response_text.replace(system_message, '').strip()\n",
+ " response_text = response_text.replace(f\"Human: {message}\\n\\nAssistant: \", '').strip()\n",
+ "\n",
+ " return response_text\n",
+ "\n",
+ "# Create the Gradio interface\n",
+ "demo = gr.ChatInterface(\n",
+ " respond,\n",
+ " additional_inputs=[\n",
+ " gr.Textbox(value=\"You are a Medical AI Assistant. Please be thorough and provide an informative answer. If you don't know the answer to a specific medical inquiry, advise seeking professional help.\", label=\"System message\"),\n",
+ " gr.Slider(minimum=1, maximum=2048, value=512, step=1, label=\"Max new tokens\"),\n",
+ " gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label=\"Temperature\"),\n",
+ " gr.Slider(\n",
+ " minimum=0.1,\n",
+ " maximum=1.0,\n",
+ " value=0.95,\n",
+ " step=0.05,\n",
+ " label=\"Top-p (nucleus sampling)\",\n",
+ " ),\n",
+ " ],\n",
+ " title=\"Medical AI Assistant\",\n",
+ " description=\"Ask any medical-related questions and get informative answers. If the AI doesn't know the answer, it will advise seeking professional help.\",\n",
+ " examples=[[\"I have a headache and a fever. What should I do?\"], [\"What are the symptoms of diabetes?\"], [\"How can I improve my sleep?\"]],\n",
+ "\n",
+ ")\n",
+ "\n",
+ "if __name__ == \"__main__\":\n",
+ " demo.launch()"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 625
+ },
+ "id": "7PPuaI3C-FUg",
+ "outputId": "b5722b5f-f2f2-4e23-fca5-d801378efa82"
+ },
+ "execution_count": 42,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).\n",
+ "\n",
+ "Colab notebook detected. To show errors in colab notebook, set debug=True in launch()\n",
+ "Running on public URL: https://12a24debf148400150.gradio.live\n",
+ "\n",
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ ""
+ ],
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {}
+ }
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/ai-medical-chatbot-master/LICENSE.txt b/ai-medical-chatbot-master/LICENSE.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3fc72effe3aeb12504ce1c367af3adc460ffd5ce
--- /dev/null
+++ b/ai-medical-chatbot-master/LICENSE.txt
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Ruslan Magana Vsevolodovna
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/ai-medical-chatbot-master/Medical-Llama3-v2.ipynb b/ai-medical-chatbot-master/Medical-Llama3-v2.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..b7f472a427ce6e2023c5115499035999c90f3f51
--- /dev/null
+++ b/ai-medical-chatbot-master/Medical-Llama3-v2.ipynb
@@ -0,0 +1,593 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "toc_visible": true,
+ "machine_shape": "hm",
+ "gpuType": "A100",
+ "authorship_tag": "ABX9TyNbD58yeZCSySm5WRgddr3c",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU",
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "01547760c189409f861090df1e625a20": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_f820192db85f499bbe68ed9864416f9d",
+ "IPY_MODEL_cd0c36d64ef4486598516011e53e6130",
+ "IPY_MODEL_40763a501c974a91b9645c0e750701dd"
+ ],
+ "layout": "IPY_MODEL_bce6464f637f4a928abba041719c8a75"
+ }
+ },
+ "f820192db85f499bbe68ed9864416f9d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_3c7a336419d74ea7bf3be11638c7a048",
+ "placeholder": "",
+ "style": "IPY_MODEL_53db4d18cc184f34b2b98d096e64a3fd",
+ "value": "Loading checkpoint shards: 100%"
+ }
+ },
+ "cd0c36d64ef4486598516011e53e6130": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "FloatProgressModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_ac7a147e018645b28007c62ef77bb3eb",
+ "max": 5,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_59e573fbcd2d41c2972329af2281097f",
+ "value": 5
+ }
+ },
+ "40763a501c974a91b9645c0e750701dd": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_d1dff55270cd46f3bd8d7a2489c8e48e",
+ "placeholder": "",
+ "style": "IPY_MODEL_4ca954cb7f2d4aa1ae2d80fcfc535dd0",
+ "value": " 5/5 [00:06<00:00, 1.02s/it]"
+ }
+ },
+ "bce6464f637f4a928abba041719c8a75": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "3c7a336419d74ea7bf3be11638c7a048": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "53db4d18cc184f34b2b98d096e64a3fd": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "ac7a147e018645b28007c62ef77bb3eb": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "59e573fbcd2d41c2972329af2281097f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ProgressStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "d1dff55270cd46f3bd8d7a2489c8e48e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "4ca954cb7f2d4aa1ae2d80fcfc535dd0": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ }
+ }
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from IPython.display import clear_output\n",
+ "!pip install bitsandbytes accelerate\n",
+ "clear_output()"
+ ],
+ "metadata": {
+ "id": "eS2NsgQgvhZQ"
+ },
+ "execution_count": 2,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n",
+ "import torch\n",
+ "\n",
+ "# Define BitsAndBytesConfig\n",
+ "bnb_config = BitsAndBytesConfig(load_in_4bit=True,\n",
+ " bnb_4bit_quant_type=\"nf4\",\n",
+ " bnb_4bit_compute_dtype=torch.float16)\n",
+ "\n",
+ "# Model name\n",
+ "model_name = \"ruslanmv/Medical-Llama3-v2\"\n",
+ "\n",
+ "# Load tokenizer and model with BitsAndBytesConfig\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
+ "model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config, device_map=\"auto\")\n",
+ "\n",
+ "# The 4-bit quantized model is dispatched to the GPU by device_map, so keep a handle to its device\n",
+ "device = model.device\n",
+ "model\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 518,
+ "referenced_widgets": [
+ "01547760c189409f861090df1e625a20",
+ "f820192db85f499bbe68ed9864416f9d",
+ "cd0c36d64ef4486598516011e53e6130",
+ "40763a501c974a91b9645c0e750701dd",
+ "bce6464f637f4a928abba041719c8a75",
+ "3c7a336419d74ea7bf3be11638c7a048",
+ "53db4d18cc184f34b2b98d096e64a3fd",
+ "ac7a147e018645b28007c62ef77bb3eb",
+ "59e573fbcd2d41c2972329af2281097f",
+ "d1dff55270cd46f3bd8d7a2489c8e48e",
+ "4ca954cb7f2d4aa1ae2d80fcfc535dd0"
+ ]
+ },
+ "id": "pAHvs3NkynJ7",
+ "outputId": "ef8c55ba-6d38-47b6-82de-86bd885c0321"
+ },
+ "execution_count": 1,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/5 [00:00, ?it/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "01547760c189409f861090df1e625a20"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "LlamaForCausalLM(\n",
+ " (model): LlamaModel(\n",
+ " (embed_tokens): Embedding(128256, 4096)\n",
+ " (layers): ModuleList(\n",
+ " (0-31): 32 x LlamaDecoderLayer(\n",
+ " (self_attn): LlamaSdpaAttention(\n",
+ " (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
+ " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
+ " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
+ " (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
+ " (rotary_emb): LlamaRotaryEmbedding()\n",
+ " )\n",
+ " (mlp): LlamaMLP(\n",
+ " (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
+ " (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
+ " (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
+ " (act_fn): SiLU()\n",
+ " )\n",
+ " (input_layernorm): LlamaRMSNorm()\n",
+ " (post_attention_layernorm): LlamaRMSNorm()\n",
+ " )\n",
+ " )\n",
+ " (norm): LlamaRMSNorm()\n",
+ " )\n",
+ " (lm_head): Linear(in_features=4096, out_features=128256, bias=False)\n",
+ ")"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 1
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "\n",
+ "# Define askme function\n",
+ "def askme(question):\n",
+ " sys_message = '''\n",
+ " You are Medical AI Assistant. Please be thorough and provide an informative answer.\n",
+ " If you don't know the answer to a specific medical inquiry, advise seeking professional help.\n",
+ " '''\n",
+ " # Create messages structured for the chat template\n",
+ " messages = [{\"role\": \"system\", \"content\": sys_message}, {\"role\": \"user\", \"content\": question}]\n",
+ "\n",
+ " # Applying chat template\n",
+ " prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
+ " inputs = tokenizer(prompt, return_tensors=\"pt\", truncation=True, padding=True, max_length=1000)\n",
+ " # Move inputs to device\n",
+ " input_ids = inputs['input_ids'].to(device)\n",
+ " attention_mask = inputs['attention_mask'].to(device)\n",
+ " # Generate outputs\n",
+ " with torch.no_grad():\n",
+ " outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=1000, use_cache=True)\n",
+ " # Extract and return the generated text, removing the prompt\n",
+ " response_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]\n",
+ " return response_text"
+ ],
+ "metadata": {
+ "id": "i8M7ENIxzFn7"
+ },
+ "execution_count": 28,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "\n",
+ "\n",
+ "# Example usage\n",
+ "question = '''I'm a 35-year-old male and for the past few months, I've been experiencing fatigue,\n",
+ "increased sensitivity to cold, and dry, itchy skin.\n",
+ "Could these symptoms be related to hypothyroidism?\n",
+ "If so, what steps should I take to get a proper diagnosis and discuss treatment options?'''\n",
+ "\n",
+ "print(askme(question))\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "iLmh9-3tyPax",
+ "outputId": "3fe7dbee-5dda-4a00-ba5a-1cff4e601186"
+ },
+ "execution_count": 29,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ " \n",
+ " You are Medical AI Assistant. Please be thorough and provide an informative answer. \n",
+ " If you don't know the answer to a specific medical inquiry, advise seeking professional help.\n",
+ " \n",
+ "\n",
+ "Human: I'm a 35-year-old male and for the past few months, I've been experiencing fatigue, \n",
+ "increased sensitivity to cold, and dry, itchy skin. \n",
+ "Could these symptoms be related to hypothyroidism? \n",
+ "If so, what steps should I take to get a proper diagnosis and discuss treatment options?\n",
+ "\n",
+ "Assistant: You should consult a physician for a proper diagnosis and discussion of treatment options. As a medical AI assistant, I do not have access to your medical history, nor can I physically examine you, so I can only provide an answer based on the information you provide. It is always best to seek the advice of a medical professional for an accurate diagnosis and treatment plan.\n",
+ "That being said, the symptoms you have described could be related to hypothyroidism, which is a condition in which the thyroid gland does not produce sufficient hormone. Hypothyroidism can have a number of causes, including a deficiency of iodine, a tumor of the pituitary gland that reduces thyroxine (T4) production, or autoimmune disease that damages the thyroid gland. In addition to the symptoms you have described, other possible signs of hypothyroidism include weight gain, cold intolerance, fatigue, dry skin, joint pain, memory loss, depression, and loss of sexual desire.\n",
+ "To determine whether your symptoms are related to hypothyroidism, your physician will start by taking a history and performing a physical examination. He or she will also order several laboratory tests, including:\n",
+ "1. TSH test: This test measures the level of TSH (thyroid-stimulating hormone) in your blood. TSH stimulates the thyroid gland to produce T4. If your TSH level is elevated, it can indicate that your thyroid gland is not producing enough T4.\n",
+ "2. T4 test: This test measures the level of T4 in your blood. T4 is the main hormone produced by the thyroid gland. If your T4 level is low, it can indicate that your thyroid gland is not functioning properly.\n",
+ "3. T3 test: This test measures the level of T3 in your blood. T3 is another hormone produced by the thyroid gland. T3 is more active than T4 and has a number of important functions in the body, including regulating metabolism.\n",
+ "4. thyroid-stimulating immunoglobulin (TSI) test: This test looks for an antibody called TSI in your blood. TSI stimulates the thyroid gland to produce more T4 and T3, even when the pituitary gland is not stimulating the thyroid gland to produce these hormones. The presence of TSI can indicate autoimmune thyroiditis.\n",
+ "5. thyroid peroxidase antibody test: This test looks for an antibody called thyroid peroxidase in your blood. This antibody attacks the thyroid gland and can cause the gland to become damaged. The presence of this antibody can indicate autoimmune thyroiditis.\n",
+ "If any of these tests suggest that you have hypothyroidism, your physician may want to order additional tests to confirm the diagnosis. If you are found to have hypothyroidism, treatment will consist of daily medication to replace the missing hormone. With proper treatment, the symptoms of hypothyroidism usually improve within two months.\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [],
+ "metadata": {
+ "id": "rW449vYm08bT"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file
diff --git a/ai-medical-chatbot-master/README.md b/ai-medical-chatbot-master/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a406d2a0472442c9f5c3bfd6e37ae7707b6e01ee
--- /dev/null
+++ b/ai-medical-chatbot-master/README.md
@@ -0,0 +1,130 @@
+# Doctor Consultation with Artificial Intelligence.
+
+*Release: **April 2024***
+
+Hello everyone, the purpose of this repository is to create a simple program that will answer medical questions by using the latest technologies of **IBM**.
+
+The aim of this program is to help people who need assistance. This program does not replace a real doctor, but it helps to identify possible health solutions.
+
+The technologies we will use are **WatsonX** from **IBM** and **Watson Assistant**.
+
+![](assets/images/posts/README/im-778762.png)
+
+**Watsonx.ai** is part of the IBM watsonx platform that brings together new generative AI capabilities, powered by foundation models, and traditional machine learning into a powerful studio spanning the AI lifecycle. With watsonx.ai, you can train, validate, tune, and deploy generative AI, foundation models, and machine learning capabilities with ease and build AI applications in a fraction of the time with a fraction of the data.
+
+We are going to use **Foundation Models** and test different models such as the following (a minimal invocation sketch follows the list):
+
+- **flan-ul2-20b** - An encoder-decoder model based on the T5 architecture and instruction-tuned using the Fine-tuned LAnguage Net. Model by Google.
+- **mt0-xxl-13b** - An instruction-tuned iteration on mT5. Model by BigScience.
+- **gpt-neox-20b** - A 20 billion parameter autoregressive language model trained on the Pile. Model by EleutherAI.
+- **flan-t5-xxl-11b** - An 11 billion parameter model based on the Flan-T5 family. Model by Google.
+- **mpt-7b-instruct** - A decoder-style transformer pretrained from scratch on 1T tokens of English text and code.
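+
+The sketch below shows how one of these foundation models could be invoked with the `ibm-watson-machine-learning` SDK; the credentials, project ID, and generation parameters are illustrative placeholders, not this project's actual configuration:
+
+```python
+# Hedged sketch: calling a watsonx.ai foundation model (credentials and params are placeholders).
+from ibm_watson_machine_learning.foundation_models import Model
+
+credentials = {
+    "url": "https://us-south.ml.cloud.ibm.com",
+    "apikey": "YOUR_IBM_CLOUD_API_KEY",
+}
+
+params = {
+    "decoding_method": "greedy",
+    "max_new_tokens": 200,
+}
+
+model = Model(
+    model_id="google/flan-ul2",
+    params=params,
+    credentials=credentials,
+    project_id="YOUR_PROJECT_ID",
+)
+
+prompt = "A patient reports fatigue and increased sensitivity to cold. What conditions could explain this?"
+print(model.generate_text(prompt=prompt))
+```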
+
+The structure of the program contains 5 parts.
+
+1. [**Environment creation**](./1-Environment/README.md)
+
+   Here we are going to create the environment to build the models locally so that they can be used later.
+
+2. [**Creation of the Medical Dataset.**](./2-Data/README.md)
+
+   In this part we are going to build the datasets that will be used to create the **Medical Model**.
+
+3. [**Creation of the model by using RAG**](./3-Modeling/README.md)
+ In this part we will perform feature engineering and create the model
+
+4. [**Finetuning Models for the Medical Chatbot**](./6-FineTunning/README.md)
+ We create a custom model based on medical information
+
+
+5. [**Multimodal Medical Chatbot**](./7-Multimodal/README.md)
+   We develop a multimodal medical chatbot that, given images, can provide a description of the issue. We analyze different medical image datasets.
+
+
+## Chatbot with WatsonX
+
+**Implementation of a chatbot with WatsonX in production.**
+
+Here we will create a chatbot with the capability to answer questions by using the Model created before.
+For production in WatsonX, you can check out this repo:
+
+
+[Watsonx-Assistant-with-Milvus-as-Vector-Database](https://github.com/ruslanmv/Watsonx-Assistant-with-Milvus-as-Vector-Database)
+
+
+## Chatbot with Custom LLM
+We have also developed another version which uses a custom LLM
+
+[Medical-Chatbot-with-Langchain-with-a-Custom-LLM](https://github.com/ruslanmv/Medical-Chatbot-with-Langchain-with-a-Custom-LLM)
+
+## Playground Demo
+
+
+**Medical-Chatbot by RAG method**.
+
+[https://huggingface.co/spaces/ruslanmv/Medical-Llama3-Chatbot](https://huggingface.co/spaces/ruslanmv/Medical-Llama3-Chatbot)
+
+[![](assets/images/posts/README/future.jpg)](https://huggingface.co/spaces/ruslanmv/AI-Medical-Chatbot)
+
+
+
+**Medical Chatbot by using Medical-Llama3-8B**
+
+[https://huggingface.co/spaces/ruslanmv/Medical-Llama3-Chatbot](https://huggingface.co/spaces/ruslanmv/Medical-Llama3-Chatbot)
+
+
+[![](assets/2024-05-16-09-23-02.png)](https://huggingface.co/spaces/ruslanmv/Medical-Llama3-Chatbot)
+
+
+
+
+## Fine-tuning Models with the ai-medical-chatbot dataset
+
+Currently there are two base models that have been fine-tuned on the ai-medical-chatbot dataset.
+
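+As a rough sketch of how the dataset can be pulled from the Hugging Face Hub and turned into prompt/response pairs for fine-tuning (the `Patient` and `Doctor` column names are assumptions taken from the dataset card; check the actual schema before training):
+
+```python
+# Minimal sketch: loading the ai-medical-chatbot dataset and formatting it
+# into prompt/response pairs. Column names are assumed from the dataset card.
+from datasets import load_dataset
+
+dataset = load_dataset("ruslanmv/ai-medical-chatbot", split="train")
+
+def format_example(example):
+    prompt = f"Patient: {example['Patient']}\nDoctor:"
+    return {"prompt": prompt, "response": example["Doctor"]}
+
+formatted = dataset.map(format_example)
+print(formatted[0]["prompt"])
+```
+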
+## Meta Llama 3
+This repository provides a fine-tuned version of the powerful Llama3 8B model, specifically designed to answer medical questions in an informative way. It leverages the rich knowledge contained in the AI Medical Chatbot dataset.
+
+
+
+
+[Medical-Llama3-8B](https://huggingface.co/ruslanmv/Medical-Llama3-8B)
+
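+A minimal usage sketch with the `transformers` library is shown below; the plain-text prompt is only an illustration, and the chat template documented on the model card should be preferred:
+
+```python
+# Minimal sketch: generating an answer with the fine-tuned Llama 3 model.
+# A plain instruction prompt is used here; the model card may define a
+# dedicated chat template that gives better results.
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "ruslanmv/Medical-Llama3-8B"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id, torch_dtype=torch.float16, device_map="auto"
+)
+
+question = "What are the symptoms of COVID-19?"
+inputs = tokenizer(question, return_tensors="pt").to(model.device)
+outputs = model.generate(**inputs, max_new_tokens=200)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```
+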
+The latest version is Medical Llama 3 v2, with an improved chatbot interface in Google Colab:
+
+
+[Medical-Llama3-v2](https://huggingface.co/ruslanmv/Medical-Llama3-v2)
+
+
+
+## Mixtral-7B
+Fine-tuned Mixtral model for answering medical assistance questions. This model is a novel version of mistralai/Mistral-7B-Instruct-v0.2, adapted to a 2.0k-record subset of the AI Medical Chatbot dataset, which contains 250k records. The purpose of this model is to provide a ready-to-use chatbot for questions related to medical assistance.
+
+[Medical-Mixtral-7B-v2k](https://huggingface.co/ruslanmv/Medical-Mixtral-7B-v2k)
+
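+A minimal inference sketch is shown below; the 4-bit quantization settings are an assumption chosen only to keep memory usage low, not the published recipe for this model:
+
+```python
+# Minimal sketch: loading Medical-Mixtral-7B-v2k with 4-bit quantization
+# (bitsandbytes). The quantization settings are an assumption, chosen only
+# to reduce memory usage.
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+
+model_id = "ruslanmv/Medical-Mixtral-7B-v2k"
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_compute_dtype=torch.float16,
+)
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id, quantization_config=bnb_config, device_map="auto"
+)
+
+question = "I have a persistent cough and mild fever. What should I do?"
+inputs = tokenizer(question, return_tensors="pt").to(model.device)
+outputs = model.generate(**inputs, max_new_tokens=200)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```
+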
+For more details on how it was fine-tuned, you can read the post [here](https://ruslanmv.com/blog/How-to-Fine-Tune-Mixtral-87B-Instruct-model-with-PEFT).
+
+> Let us use the best technologies in the world to help us.
+
+
+
+## Medical Interviewer
+[![](assets/2024-09-08-19-33-56.png)](https://huggingface.co/spaces/ruslanmv/Medical-Interviewer)
+
+A chatbot that performs medical interviews.
+
+For more details, visit [this page](./8-Interviewer/README.md).
+
+
+## Contributing
+
+Please feel free to contribute by following the standard guidelines for submitting patches, additions, or fixes. Feel free to submit issues and enhancement requests.
+
+For more information, visit www.ruslanmv.com.
+
+Copyright 2024 Ruslan Magana Vsevolodovna. This program is distributed under the terms of the GNU Lesser General Public License.
+
+
+
+
+
diff --git a/ai-medical-chatbot-master/assets/2024-05-16-09-23-02.png b/ai-medical-chatbot-master/assets/2024-05-16-09-23-02.png
new file mode 100644
index 0000000000000000000000000000000000000000..61aa7f6c9835f173fb9d1809f10a6b06977d8add
--- /dev/null
+++ b/ai-medical-chatbot-master/assets/2024-05-16-09-23-02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:92d0613ae78bfe0449c6d8b6d81242274ed805f57d0dafe1c28b85e1e56f388d
+size 5133421
diff --git a/ai-medical-chatbot-master/assets/2024-09-08-19-33-56.png b/ai-medical-chatbot-master/assets/2024-09-08-19-33-56.png
new file mode 100644
index 0000000000000000000000000000000000000000..01ab6378fbee94e9b0e04d17daf8cf4fb95f076c
Binary files /dev/null and b/ai-medical-chatbot-master/assets/2024-09-08-19-33-56.png differ
diff --git a/ai-medical-chatbot-master/assets/images/background.jpg b/ai-medical-chatbot-master/assets/images/background.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4d74d7720535651a804f82e10dc87f39961986b9
Binary files /dev/null and b/ai-medical-chatbot-master/assets/images/background.jpg differ
diff --git a/ai-medical-chatbot-master/assets/images/posts/README/future-full.jpg b/ai-medical-chatbot-master/assets/images/posts/README/future-full.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..75face20fc5eb994abaec76c989ba2c0bdaf2b49
Binary files /dev/null and b/ai-medical-chatbot-master/assets/images/posts/README/future-full.jpg differ
diff --git a/ai-medical-chatbot-master/assets/images/posts/README/future.jpg b/ai-medical-chatbot-master/assets/images/posts/README/future.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f70a0e591cdb155df9e78011a6de5861798e8e07
Binary files /dev/null and b/ai-medical-chatbot-master/assets/images/posts/README/future.jpg differ
diff --git a/ai-medical-chatbot-master/assets/images/posts/README/im-778762.png b/ai-medical-chatbot-master/assets/images/posts/README/im-778762.png
new file mode 100644
index 0000000000000000000000000000000000000000..5960dfb86465e09a1eef64f0810b37d1fd9c3b77
Binary files /dev/null and b/ai-medical-chatbot-master/assets/images/posts/README/im-778762.png differ
diff --git a/ai-medical-chatbot-master/env.bat b/ai-medical-chatbot-master/env.bat
new file mode 100644
index 0000000000000000000000000000000000000000..986bcb96834a3db392f668d35096162d73c6dc46
--- /dev/null
+++ b/ai-medical-chatbot-master/env.bat
@@ -0,0 +1 @@
+.venv\Scripts\activate
\ No newline at end of file
diff --git a/ai-medical-chatbot-master/env.sh b/ai-medical-chatbot-master/env.sh
new file mode 100644
index 0000000000000000000000000000000000000000..cf7aa01724eb20fb43747ff31b5715f811847755
--- /dev/null
+++ b/ai-medical-chatbot-master/env.sh
@@ -0,0 +1 @@
+source gpt/my_venv/bin/activate
\ No newline at end of file