Manju017 commited on
Commit
bce1941
·
verified ·
1 Parent(s): 1101f80

Update app.py to include the necessary imports and settings so that the Accelerate library is used effectively.

Browse files
Files changed (1) hide show
  1. app.py +18 -11
app.py CHANGED
@@ -1,30 +1,37 @@
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
3
 
 
4
  model_name = "ai4bharat/Airavata"
5
 
6
- # Load the model in 8-bit precision to reduce memory usage
 
 
 
 
 
 
7
  model = AutoModelForCausalLM.from_pretrained(
8
  model_name,
9
- device_map="auto",
10
- load_in_8bit=True
11
  )
12
- tokenizer = AutoTokenizer.from_pretrained(model_name)
13
 
14
- def generate_text(prompt, max_length):
 
15
  inputs = tokenizer(prompt, return_tensors="pt")
16
- outputs = model.generate(**inputs, max_length=max_length)
17
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
18
 
 
19
  interface = gr.Interface(
20
  fn=generate_text,
21
- inputs=[
22
- gr.inputs.Textbox(label="Enter your prompt"),
23
- gr.inputs.Slider(10, 100, step=10, label="Max length")
24
- ],
25
  outputs="text",
26
  title="Airavata Text Generation Model",
27
- description="Generate text in Indic languages using the Airavata model."
28
  )
29
 
 
30
  interface.launch()
 
import gradio as gr
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from accelerate import infer_auto_device_map, init_empty_weights

# Hugging Face Hub id of the AI4Bharat Airavata model.
model_name = "ai4bharat/Airavata"

# Load the tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# BUG FIX: infer_auto_device_map() expects an instantiated model, not a
# repo-id string — passing `model_name` raises at startup.  Build an
# empty (meta-device) model from the config so the device map can be
# computed without allocating real weights.
config = AutoConfig.from_pretrained(model_name)
with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_config(config)
device_map = infer_auto_device_map(empty_model)

# Load the real weights, placed across devices per the computed map.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device_map,
    load_in_8bit=True,  # 8-bit precision to reduce memory usage
)
 
20
 
21
# Inference function wired into the Gradio interface below.
def generate_text(prompt, max_new_tokens=None):
    """Generate a continuation of *prompt* with the Airavata model.

    Args:
        prompt: Input text to condition generation on.
        max_new_tokens: Optional cap on the number of newly generated
            tokens. ``None`` (the default) keeps the model's own
            generation defaults, preserving the original behaviour.

    Returns:
        The generated sequence decoded to a string, with special tokens
        stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    # Only forward the cap when the caller supplied one, so the default
    # call path is byte-for-byte equivalent to model.generate(**inputs).
    extra = {} if max_new_tokens is None else {"max_new_tokens": max_new_tokens}
    outputs = model.generate(**inputs, **extra)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
26
 
27
# Build the Gradio UI: one free-text input, generated text as output.
interface = gr.Interface(
    generate_text,
    inputs="text",
    outputs="text",
    title="Airavata Text Generation Model",
    description="This is the AI4Bharat Airavata model for text generation in Indic languages.",
)

# Start the web app.
interface.launch()