Haseeb javed commited on
Commit
344f451
·
1 Parent(s): 1f613fd

first commit

Browse files
Files changed (2) hide show
  1. app.py +49 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ from flask_cors import CORS
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ import logging
6
+ import os
7
+
8
+ MIN_TRANSFORMERS_VERSION = '4.25.1'
9
+
10
+ # Check transformers version
11
+ import transformers
12
+ assert transformers.__version__ >= MIN_TRANSFORMERS_VERSION, f'Please upgrade transformers to version {MIN_TRANSFORMERS_VERSION} or higher.'
13
+
14
+ # Initialize tokenizer and model from local directory
15
+ model_dir = "./"
16
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
17
+ model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16)
18
+
19
+ app = Flask(__name__)
20
+ CORS(app) # Enable CORS
21
+
22
+ logging.basicConfig(level=logging.DEBUG)
23
+
24
+ def generate_response(prompt):
25
+ inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
26
+ input_length = inputs.input_ids.shape[1]
27
+ outputs = model.generate(
28
+ **inputs, max_new_tokens=128, do_sample=True, temperature=0.7, top_p=0.7, top_k=50, return_dict_in_generate=True
29
+ )
30
+ token = outputs.sequences[0, input_length:]
31
+ output_str = tokenizer.decode(token, skip_special_tokens=True)
32
+ return output_str
33
+
34
+ @app.route('/chat', methods=['POST'])
35
+ def chat():
36
+ logging.debug("Received a POST request")
37
+ data = request.json
38
+ logging.debug(f"Request data: {data}")
39
+ user_input = data.get("message", "")
40
+ prompt = f"<human>: {user_input}\n<bot>:"
41
+ response = generate_response(prompt)
42
+ logging.debug(f"Generated response: {response}")
43
+ return jsonify({"response": response})
44
+
45
+ if __name__ == "__main__":
46
+ # Get the port from environment variable or default to 5000
47
+ port = int(os.getenv("PORT", 5000))
48
+ logging.info(f"Starting Flask app on port {port}")
49
+ app.run(debug=True, host="0.0.0.0", port=port)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ flask
2
+ flask-cors
3
+ torch
4
+ transformers