import os
from threading import Thread
from typing import Iterator
import requests
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hugging Face access token for the gated meta-llama repository, read from the environment.
HF_TOKEN = os.getenv("HF_TOKEN")

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

DESCRIPTION = """\
# Llama. Protected. With Protecto.
"""

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU. This demo requires a GPU.</p>"


# The model and tokenizer are only loaded when a GPU is available; on a CPU-only
# machine, `generate` below will fail because `model` and `tokenizer` are undefined.
if torch.cuda.is_available():
    model_id = "meta-llama/Llama-2-7b-chat-hf"
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float16, device_map="auto", use_auth_token=HF_TOKEN
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=HF_TOKEN)
    tokenizer.use_default_system_prompt = False

@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    conversation = []

    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
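    # model.generate runs in a background thread (started below); the streamer yields
    # decoded text chunks as soon as they are produced.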
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        # Mask the accumulated output with Protecto before streaming it to the UI.
        # Note: this issues one masking API call per streamed chunk.
        masked_output = mask_with_protecto("".join(outputs))
        yield masked_output

    
def mask_with_protecto(text_for_prompt):
    # Send the text to the Protecto vault masking API and return the masked version.
    mask_request_url = "https://trial.protecto.ai/api/vault/mask"
    headers = {
        "Content-Type": "application/json; charset=utf-8",
        # The Protecto trial API token is expected in the PROTECTO_API_TOKEN environment variable.
        "Authorization": f"Bearer {os.getenv('PROTECTO_API_TOKEN')}",
    }
    mask_input = {
        "mask": [
            {
                "value": text_for_prompt
            }
        ]
    }
    response = requests.put(mask_request_url, headers=headers, json=mask_input)
    if response.status_code == 200:
        # Parse the masked result from the API response and return the masked text.
        masked_result = response.json()
        return masked_result["data"][0]["token_value"]
    else:
        # Log the error and fall back to the unmasked text so the stream is not interrupted.
        print(f"Failed to get a successful response. Status Code: {response.status_code}")
        print(response.text)
        return text_for_prompt
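
# A small optional sanity check for the masking helper, gated behind an opt-in
# PROTECTO_SELF_TEST environment variable (a name chosen here, not required by the app)
# so it never runs by default. It assumes the Protecto trial endpoint returns the
# {"data": [{"token_value": ...}]} shape consumed above; the sample text is illustrative.
if os.getenv("PROTECTO_SELF_TEST") == "1":
    sample_text = "Hi, I am John Smith and you can reach me at john.smith@example.com."
    print(mask_with_protecto(sample_text))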

chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
    ],
    retry_btn=None, 
    stop_btn=None, 
    undo_btn=None, 
    clear_btn=None,
)
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()