mew77 committed on
Commit
29fabcc
·
verified ·
1 Parent(s): 55c1a89

Update hf_model.py

Browse files
Files changed (1) hide show
  1. hf_model.py +11 -1
hf_model.py CHANGED
@@ -23,7 +23,17 @@ class HFModel:
23
  def generate_response(self, input_text, max_length=100, skip_special_tokens=True):
24
  try:
25
  inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
26
- outputs = self.model.generate(**inputs, max_length=max_length)
 
 
 
 
 
 
 
 
 
 
27
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=skip_special_tokens).strip()
28
  log_info(f"Generated Response: {response}")
29
  return response
 
23
  def generate_response(self, input_text, max_length=100, skip_special_tokens=True):
24
  try:
25
  inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
26
+ outputs = self.model.generate(
27
+ **inputs,
28
+ max_length=max_length,
29
+ pad_token_id=self.tokenizer.eos_token_id, # Ensure proper padding
30
+ do_sample=True, # Enable sampling for more diverse outputs
31
+ top_k=50, # Limit sampling to top-k tokens
32
+ top_p=0.95, # Use nucleus sampling
33
+ temperature=0.7, # Control randomness
34
+ )
35
+ #inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
36
+ #outputs = self.model.generate(**inputs, max_length=max_length)
37
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=skip_special_tokens).strip()
38
  log_info(f"Generated Response: {response}")
39
  return response