File size: 2,768 Bytes
894b24d
 
 
 
 
 
 
 
 
 
ed49033
1033026
894b24d
 
 
 
 
 
 
625b3d8
04f4e3d
 
894b24d
 
 
 
 
 
 
 
 
 
 
 
 
76268a8
894b24d
cb8fae7
894b24d
 
 
 
 
 
 
 
 
 
 
15c9ed1
3f2b796
bca5c68
3f2b796
bca5c68
894b24d
 
 
 
 
15c9ed1
894b24d
 
d3f1526
 
894b24d
d3f1526
d4df546
ecd91ef
72fee02
 
 
 
57aa9a4
72fee02
953205d
1033026
72fee02
 
953205d
ed49033
 
57aa9a4
 
 
 
 
 
 
16e6449
72fee02
625b3d8
cc4118d
d4df546
52ded96
894b24d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from metrics import calc_metrics
import gradio as gr
from openai import OpenAI
import os

from transformers import pipeline
# from dotenv import load_dotenv, find_dotenv
import huggingface_hub
import json
from evaluate_data import store_sample_data, get_metrics_trf
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

# Log in to the Hugging Face Hub before any model is fetched.
# Indexing os.environ directly means a missing HF_TOKEN raises KeyError here.
hf_token = os.environ['HF_TOKEN']
huggingface_hub.login(hf_token)

# Token-classification pipeline built from the fine-tuned checkpoint below.
pipe = pipeline(
    "token-classification",
    model="elshehawy/finer-ord-transformers",
    aggregation_strategy="first",
)

# Chat model used by default for completions.
# Alternatives noted by the author: 'gpt-3.5-turbo-0125', 'gpt-4-0125-preview'.
llm_model = 'gpt-3.5-turbo-0301'

# os.environ.get yields None when OPENAI_API_KEY is unset (no KeyError here).
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))


def get_completion(prompt, model=llm_model):
    """Send ``prompt`` as a single user message and return the reply text.

    Args:
        prompt: Text used as the content of the only (user) message.
        model: Name of the chat model; defaults to the module-level
            ``llm_model``.

    Returns:
        The ``content`` of the first choice of the chat completion.
    """
    completion = client.chat.completions.create(
        model=model,
        temperature=0,
        messages=[{"role": "user", "content": prompt}],
    )
    first_choice = completion.choices[0]
    return first_choice.message.content


def find_orgs_gpt(sentence):
    """Ask the chat model for the organizations mentioned in ``sentence``.

    The model is instructed to answer with a JSON object of the form
    ``{"Organizations": [...]}``. The reply is parsed as JSON; if the raw
    reply is not valid JSON (for example because it is wrapped in Markdown
    code fences or surrounded by extra text), the outermost ``{...}`` span
    is parsed instead.

    Args:
        sentence: Text to search for organization names.

    Returns:
        The value stored under the ``"Organizations"`` key of the reply.

    Raises:
        json.JSONDecodeError: If no JSON object can be parsed from the reply.
        KeyError: If the parsed object has no ``"Organizations"`` key.
    """
    # The prompt text is sent to the model verbatim; keep it unchanged so
    # model behaviour stays comparable across runs.
    prompt = f"""
    In context of named entity recognition (NER), find all organizations in the text delimited by triple backticks.
    
    text:
    ```
    {sentence}
    ```
    Your output should be a a json object that containes the extracted organizations.
    Output example 1:
    {{\"Organizations\": [\"Organization 1\", \"Organization 2\", \"Organization 3\"]}}
    Output example 2:
    {{\"Organizations\": []}}
    """

    sent_orgs_str = get_completion(prompt)
    try:
        sent_orgs = json.loads(sent_orgs_str)
    except json.JSONDecodeError:
        # JSON mode is not enforced on the request, so the reply may carry
        # code fences or commentary around the object. Retry on the
        # outermost brace-delimited span; re-raise if there is none.
        start = sent_orgs_str.find('{')
        end = sent_orgs_str.rfind('}')
        if start == -1 or end <= start:
            raise
        sent_orgs = json.loads(sent_orgs_str[start:end + 1])

    return sent_orgs['Organizations']


# Sample input sentence; nothing in the visible code reads this constant.
example = (
    "\n"
    "My latest exclusive for The Hill : Conservative frustration over Republican efforts to force a House vote on reauthorizing the Export - Import Bank boiled over Wednesday during a contentious GOP meeting.\n"
    "\n"
)
def find_orgs(uploaded_file):
    """Compute organization-extraction metrics for an uploaded JSON document.

    The upload is parsed with ``json.loads``. Each sample produced by
    ``store_sample_data`` is sent to ``find_orgs_gpt`` and the predictions
    are scored against the sample's ``'orgs'`` with ``calc_metrics``.
    Metrics from ``get_metrics_trf`` on the same parsed upload are added
    under a second key.

    Args:
        uploaded_file: JSON payload in a form accepted by ``json.loads``.

    Returns:
        A dict with the ``calc_metrics`` result under ``'gpt'`` and the
        ``get_metrics_trf`` result under ``'trf'``.
    """
    print(type(uploaded_file))
    parsed_upload = json.loads(uploaded_file)

    gold_orgs = []
    predicted_orgs = []
    # Single pass: query the model first, then record the reference labels.
    for sample in tqdm(store_sample_data(parsed_upload)):
        predicted_orgs.append(find_orgs_gpt(sample['text']))
        gold_orgs.append(sample['orgs'])

    # Sentence-embedding model handed to calc_metrics together with the
    # 0.85 threshold.
    encoder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    all_metrics = {
        'gpt': calc_metrics(gold_orgs, predicted_orgs, encoder, threshold=0.85),
    }
    print(all_metrics)

    all_metrics['trf'] = get_metrics_trf(parsed_upload)
    print(all_metrics)
    return all_metrics

# Upload control configured with type='binary'; its value is the input
# passed to find_orgs.
upload_btn = gr.UploadButton(
    type='binary',
    label='Upload a json file.',
)

# Wire the upload to find_orgs and show the returned value as text.
iface = gr.Interface(
    fn=find_orgs,
    inputs=upload_btn,
    outputs="text",
)
iface.launch(share=True)