import streamlit as st
import uuid
import sys
import requests
import bitsandbytes as bnb
import pandas as pd
import torch
import torch.nn as nn
import transformers
from datasets import load_dataset
from huggingface_hub import notebook_login
from peft import (
    LoraConfig,
    PeftConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
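# Icon assets shown next to messages and the maximum number of exchanges kept in chat history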
USER_ICON = "images/user-icon.png"
AI_ICON = "images/ai-icon.png"
MAX_HISTORY_LENGTH = 5
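# Assign a per-session user id once and reuse it on subsequent reruns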
if 'user_id' in st.session_state:
    user_id = st.session_state['user_id']
else:
    user_id = str(uuid.uuid4())
    st.session_state['user_id'] = user_id

if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []

if "chats" not in st.session_state:
    st.session_state.chats = [
        {
            'id': 0,
            'question': '',
            'answer': ''
        }
    ]

if "questions" not in st.session_state:
    st.session_state.questions = []

if "answers" not in st.session_state:
    st.session_state.answers = []

if "input" not in st.session_state:
    st.session_state.input = ""
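# Page styling: container padding, icon background colour, and header font size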
st.markdown("""
<style>
    .block-container {
        padding-top: 32px;
        padding-bottom: 32px;
        padding-left: 0;
        padding-right: 0;
    }
    .element-container img {
        background-color: #000000;
    }
    .main-header {
        font-size: 24px;
    }
</style>
""", unsafe_allow_html=True)
def write_top_bar():
    col1, col2, col3 = st.columns([1, 10, 2])
    with col1:
        st.image(AI_ICON, use_column_width='always')
    with col2:
        header = "Cogwise Intelligent Assistant"
        st.write(f"<h3 class='main-header'>{header}</h3>", unsafe_allow_html=True)
    with col3:
        clear = st.button("Clear Chat")
    return clear

clear = write_top_bar()

if clear:
    st.session_state.questions = []
    st.session_state.answers = []
    st.session_state.input = ""
    st.session_state["chat_history"] = []
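# handle_input() runs each time a question is submitted: it loads the 4-bit quantized
# Falcon-7B base model with the "nisaar/falcon7b-Indian_Law_150Prompts" LoRA adapter
# and generates the answer (the model is reloaded on every call, as nothing is cached).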
def handle_input():
    input = st.session_state.input
    question_with_id = {
        'question': input,
        'id': len(st.session_state.questions)
    }
    st.session_state.questions.append(question_with_id)

    chat_history = st.session_state["chat_history"]
    if len(chat_history) == MAX_HISTORY_LENGTH:
        # Drop the oldest exchange so the stored history stays within the cap
        chat_history.pop(0)
    # api_url = "https://9pl792yjf9.execute-api.us-east-1.amazonaws.com/beta/chatcogwise"
    # api_request_data = {"question": input, "session": user_id}
    # api_response = requests.post(api_url, json=api_request_data)
    # result = api_response.json()
    # answer = result['answer']
    # !pip install -Uqqq pip --progress-bar off
    # !pip install -qqq bitsandbytes==0.39.0
    # !pip install -qqq torch==2.0.1 --progress-bar off
    # !pip install -qqq -U git+https://github.com/huggingface/transformers.git@e03a9cc --progress-bar off
    # !pip install -qqq -U git+https://github.com/huggingface/peft.git@42a184f --progress-bar off
    # !pip install -qqq -U git+https://github.com/huggingface/accelerate.git@c9fbb71 --progress-bar off
    # !pip install -qqq datasets==2.12.0 --progress-bar off
    # !pip install -qqq loralib==0.1.1 --progress-bar off
    # !pip install einops
    import os
    # from pprint import pprint
    # import json

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    # notebook_login()
    # <Hugging Face access token redacted>

    # """### Load dataset"""
    from datasets import load_dataset

    dataset_name = "nisaar/Lawyer_GPT_India"
    # dataset_name = "patrick11434/TEST_LLM_DATASET"
    dataset = load_dataset(dataset_name, split="train")
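    # Note: the dataset loaded above is not used anywhere in the inference code below.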
    # """## Load adapters from the Hub
    # You can also directly load adapters from the Hub using the commands below:
    # """

    # change peft_model_id
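    # 4-bit NF4 quantization config so the Falcon-7B base model fits in limited GPU memory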
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    peft_model_id = "nisaar/falcon7b-Indian_Law_150Prompts"
    config = PeftConfig.from_pretrained(peft_model_id)
    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        return_dict=True,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token
    # Apply the LoRA adapter weights on top of the quantized base model
    model = PeftModel.from_pretrained(model, peft_model_id)

    # ## Inference
    # You can then directly use the trained model, or the model you have loaded from the
    # 🤗 Hub, for inference just as you usually would in `transformers`.
    generation_config = model.generation_config
    generation_config.max_new_tokens = 200
    generation_config.temperature = 1
    generation_config.top_p = 0.7
    generation_config.num_return_sequences = 1
    generation_config.pad_token_id = tokenizer.eos_token_id
    generation_config.eos_token_id = tokenizer.eos_token_id
    DEVICE = "cuda:0"

    # Commented out IPython magic to ensure Python compatibility.
    # %%time
    # prompt = f"""
    # <human>: Who appoints the Chief Justice of India?
    # <assistant>:
    # """.strip()
    #
    # encoding = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    # with torch.inference_mode():
    #     outputs = model.generate(
    #         input_ids=encoding.input_ids,
    #         attention_mask=encoding.attention_mask,
    #         generation_config=generation_config,
    #     )
    # print(tokenizer.decode(outputs[0], skip_special_tokens=True))
    def generate_response(question: str) -> str:
        prompt = f"""
<human>: {question}
<assistant>:
""".strip()
        encoding = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        with torch.inference_mode():
            outputs = model.generate(
                input_ids=encoding.input_ids,
                attention_mask=encoding.attention_mask,
                generation_config=generation_config,
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        assistant_start = '<assistant>:'
        response_start = response.find(assistant_start)
        return response[response_start + len(assistant_start):].strip()
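    # Example (illustrative): generate_response("Who appoints the Chief Justice of India?")
    # returns only the text the model generates after the "<assistant>:" marker.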
    # prompt = "Debate the merits and demerits of introducing simultaneous elections in India?"
    prompt = input
    answer = generate_response(prompt)
    # answer = 'Yes'
    chat_history.append((input, answer))

    st.session_state.answers.append({
        'answer': answer,
        'id': len(st.session_state.questions)
    })
    st.session_state.input = ""
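# Chat rendering helpers: user questions in a warning box, model answers in an info box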
def write_user_message(md):
    col1, col2 = st.columns([1, 12])
    with col1:
        st.image(USER_ICON, use_column_width='always')
    with col2:
        st.warning(md['question'])

def render_answer(answer):
    col1, col2 = st.columns([1, 12])
    with col1:
        st.image(AI_ICON, use_column_width='always')
    with col2:
        st.info(answer)

def write_chat_message(md, q):
    chat = st.container()
    with chat:
        render_answer(md['answer'])

with st.container():
    for (q, a) in zip(st.session_state.questions, st.session_state.answers):
        write_user_message(q)
        write_chat_message(a, q)

st.markdown('---')
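# Text input at the bottom of the page; submitting a question triggers handle_input() above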
input = st.text_input("You are talking to an AI, ask any question.", key="input", on_change=handle_input)