---
base_model: google/gemma-2b
library_name: peft
license: apache-2.0
datasets:
- Rahulholla/stock-analysis
language:
- en
pipeline_tag: question-answering
tags:
- finance
---

# Model Card for vdpappu/lora_stock_analysis

A LoRA adapter for google/gemma-2b, fine-tuned for financial question answering on the Rahulholla/stock-analysis dataset.

- **Developed by:** Venkat
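
Prompts follow a simple question/answer template; the `generate_prompt` helper in the example below produces this shape (the instruction line is included only when an instruction is passed):

```text
### Instruction: <optional instruction>

### Question: <your question>

### Answer:
```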

## How to Get Started with the Model
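
The snippet below loads the base model, applies the LoRA adapter, merges the adapter weights into the base, and samples an answer: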
```python
import os
from typing import Optional

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


def generate_prompt(input_text: str, instruction: Optional[str] = None) -> str:
    """Build the question/answer prompt format the adapter was trained on."""
    text = f"### Question: {input_text}\n\n### Answer: "
    if instruction:
        text = f"### Instruction: {instruction}\n\n{text}"
    return text


# google/gemma-2b is gated: an authorized Hugging Face token is required.
huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")

base_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", token=huggingface_token)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", token=huggingface_token)

# Apply the LoRA adapter, then merge its weights into the base model.
lora_model = PeftModel.from_pretrained(base_model, "vdpappu/lora_stock_analysis")
merged_model = lora_model.merge_and_unload()

generation_config = GenerationConfig(
    eos_token_id=tokenizer.eos_token_id,
    min_new_tokens=5,
    max_new_tokens=200,  # budget for the answer, independent of prompt length
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    repetition_penalty=1.5,
    no_repeat_ngram_size=3,
)

question = """Assume the role of a seasoned stock option analyst with a strong track record in dissecting intricate option data to discern valuable
insights into stock sentiment. Proficient in utilizing advanced statistical models and data visualization techniques to forecast
market trends and make informed trading decisions. Adept at interpreting option Greeks, implied volatility, .. """
prompt = generate_prompt(input_text=question)

with torch.no_grad():
    inputs = tokenizer(prompt, return_tensors="pt")
    output = merged_model.generate(**inputs, generation_config=generation_config)
response = tokenizer.decode(output[0], skip_special_tokens=True)

print(response)
```
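
Merging is optional. If you skip the `merge_and_unload()` step, the `PeftModel` wrapper generates directly with the adapter still attached, which leaves the base weights untouched; a minimal sketch under that assumption, reusing `tokenizer`, `prompt`, and `generation_config` from above:

```python
# Sketch: generate with the adapter attached (merge_and_unload() skipped),
# so lora_model still wraps the unmodified base weights.
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    output = lora_model.generate(**inputs, generation_config=generation_config)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

Keeping the adapter unmerged also lets several task adapters share a single copy of the base model.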
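
On a CUDA machine you can roughly halve memory by loading the base model in float16 before applying the adapter. A sketch, assuming a GPU is available and the `accelerate` package is installed (required by `device_map="auto"`):

```python
import torch
from transformers import AutoModelForCausalLM

# Assumption: CUDA GPU available; huggingface_token and tokenizer as above.
base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    torch_dtype=torch.float16,
    device_map="auto",
    token=huggingface_token,
)
# Inputs must live on the same device as the model:
inputs = tokenizer(prompt, return_tensors="pt").to(base_model.device)
```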

### Framework versions

- PEFT 0.12.0