Haseeb javed committed
Commit db9fe30 · 1 Parent(s): a996985

runtime error fixes

Files changed (1)
  1. app.py +22 -23
app.py CHANGED
```diff
@@ -5,45 +5,44 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 import logging
 import os

-MIN_TRANSFORMERS_VERSION = '4.25.1'
-
-# Check transformers version
-import transformers
-assert transformers.__version__ >= MIN_TRANSFORMERS_VERSION, f'Please upgrade transformers to version {MIN_TRANSFORMERS_VERSION} or higher.'
+# Initialize logger
+logging.basicConfig(level=logging.DEBUG)

-# Initialize tokenizer and model from local directory
-model_dir = "hsb06/toghetherAi-model"
-tokenizer = AutoTokenizer.from_pretrained(model_dir)
-model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16)
+# Load tokenizer and model
+logging.info("Loading model...")
+model_repo = "hsb06/toghetherAi-model"
+tokenizer = AutoTokenizer.from_pretrained(model_repo, use_auth_token=True)
+model = AutoModelForCausalLM.from_pretrained(model_repo, torch_dtype=torch.float16, use_auth_token=True).to("cuda" if torch.cuda.is_available() else "cpu")
+logging.info("Model loaded successfully.")

+# Initialize Flask app
 app = Flask(__name__)
-CORS(app)  # Enable CORS
-
-logging.basicConfig(level=logging.DEBUG)
+CORS(app)

 def generate_response(prompt):
-    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
     input_length = inputs.input_ids.shape[1]
     outputs = model.generate(
-        **inputs, max_new_tokens=128, do_sample=True, temperature=0.7, top_p=0.7, top_k=50, return_dict_in_generate=True
+        **inputs,
+        max_new_tokens=128,
+        do_sample=True,
+        temperature=0.7,
+        top_p=0.7,
+        top_k=50,
+        return_dict_in_generate=True,
     )
     token = outputs.sequences[0, input_length:]
-    output_str = tokenizer.decode(token, skip_special_tokens=True)
-    return output_str
+    return tokenizer.decode(token, skip_special_tokens=True)

-@app.route('/chat', methods=['POST'])
+@app.route("/chat", methods=["POST"])
 def chat():
-    logging.debug("Received a POST request")
     data = request.json
-    logging.debug(f"Request data: {data}")
     user_input = data.get("message", "")
     prompt = f"<human>: {user_input}\n<bot>:"
     response = generate_response(prompt)
-    logging.debug(f"Generated response: {response}")
     return jsonify({"response": response})

 if __name__ == "__main__":
-    # Get the port from environment variable or default to 5000
-    port = int(os.getenv("PORT", 5000))
+    port = int(os.getenv("PORT", 7860))  # Default to 7860
     logging.info(f"Starting Flask app on port {port}")
-    app.run(debug=True, host="0.0.0.0", port=port)
+    app.run(debug=False, host="0.0.0.0", port=port)
```
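To verify the changed endpoint, a minimal client sketch (assuming the app is running locally on the new default port 7860; the `/chat` route and the `message`/`response` JSON fields are taken from the diff above):

```python
import requests

# POST a test message to the /chat endpoint and print the model's reply.
resp = requests.post(
    "http://localhost:7860/chat",
    json={"message": "Hello, who are you?"},
)
resp.raise_for_status()
print(resp.json()["response"])
```

The new defaults line up with Hugging Face Spaces conventions: 7860 is the port a Space expects the app to listen on, and `debug=False` avoids running Flask's debug reloader in a deployed container.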