Test task
For model inference, run the following:
import torch
from peft import PeftModel, PeftConfig
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers import LlamaTokenizer
# Fix the torch RNG seeds (CPU and all CUDA devices) so runs are reproducible.
seed_value = 42
torch.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)

model_name = "lmsys/vicuna-7b-v1.5"  # base model
lora_name = 'AlexWortega/PaltaTest'  # LoRA adapter applied on top of the base model

# `device` is read by generate_seqs(); use the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only reliably supported on GPU; fall back to float32 on CPU.
dtype = torch.float16 if device == "cuda" else torch.float32

tokenizer = LlamaTokenizer.from_pretrained(model_name, model_max_length=1024)
# Use the EOS token for padding.
tokenizer.pad_token = tokenizer.eos_token

# Load the base model first, then wrap it with the LoRA adapter; the adapter
# cannot be attached before the base model exists.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype)
model = PeftModel.from_pretrained(model, lora_name)
model.to(device)
model.eval()  # inference mode (disables dropout)
def process_output(i, o):
    """Strip the ``Q:...A:`` prompt from decoded generations, keeping the answer.

    Args:
        i: The original question. Unused; kept so existing calls such as
            ``process_output(q, s)`` keep working.
        o: A decoded sequence (str) or a list of decoded sequences.

    Returns:
        For a str, the text after the first ``'A:'`` marker; for a list, a
        list of such texts. A sequence without the marker is returned
        unchanged. Any other input type yields an explanatory message string.
    """
    def _answer(seq):
        # Split only on the FIRST marker so an "A:" inside the generated
        # answer does not truncate it, and a missing marker does not raise.
        _, sep, answer = seq.partition('A:')
        return answer if sep else seq

    if isinstance(o, list):
        return [_answer(seq) for seq in o]
    if isinstance(o, str):
        return _answer(o)
    return "Unsupported data type. Please provide a list or a string."
def generate_seqs(q, k=2, **generate_kwargs):
    """Generate a completion for question ``q`` with the adapted model.

    The question is wrapped in the ``Q:<question>A:`` prompt format, passed to
    ``model.generate`` and the resulting token ids are decoded.

    Args:
        q: The question/instruction text.
        k: Currently unused; kept for backward compatibility.
        **generate_kwargs: Extra keyword arguments forwarded to
            ``model.generate`` (e.g. ``max_new_tokens=256``).

    Returns:
        A list of decoded strings (special tokens skipped), one per
        generated sequence.
    """
    prompt = 'Q:' + q + 'A:'
    inputs = tokenizer(prompt, return_tensors='pt').to(device)
    # pad_token is set to eos_token, so generate() cannot infer the attention
    # mask from the ids alone; pass it and the pad id explicitly.
    generate_kwargs.setdefault('pad_token_id', tokenizer.eos_token_id)
    g = model.generate(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        **generate_kwargs,
    )
    return tokenizer.batch_decode(g, skip_special_tokens=True)
# Example: rewrite a weather report in the style of an "Angry weatherman".
q = """Given a weather description in plain text, rewrite it in a different style
```The weather is sunny and the temperature is 20 degrees. The wind is blowing at 10 km/h.
Citizens are advised to go out and enjoy the weather. The weather is expected to be sunny tomorrow.
And the following style: "Angry weatherman" """
s = generate_seqs(q=q)
s = process_output(q, s)
print(s[0])
# Should output something like this:
#
#   Angry weatherman: "The weather is sunny and the temperature is 20 degrees. The wind is blowing at 10 km/h.
#   Citizens are advised to stay indoors and avoid going out. The weather is expected to be sunny tomorrow.
Inference Providers (new)
This model is not currently available via any of the supported Inference Providers.
It also cannot be deployed to the HF Inference API because it has no `pipeline_tag`.