|
import os |
|
import gradio as gr |
|
from gradio import FlaggingCallback |
|
from gradio.components import IOComponent |
|
|
|
from transformers import pipeline |
|
|
|
from typing import List, Optional, Any |
|
|
|
import argilla as rg |
|
|
|
import os |
|
|
|
|
|
|
|
# NER pipeline loaded from the "deprem-ml/deprem-ner" checkpoint; shared by
# every call that needs predictions or the matching tokenizer.
nlp = pipeline(task="ner", model="deprem-ml/deprem-ner")

# Sample inputs, one single-element row per example.
examples = [
    [
        "Lütfen yardım Akevler mahallesi Rüzgar sokak Tuncay apartmanı zemin kat Antakya akrabalarım göçük altında #hatay #Afad",
    ],
]
|
|
|
def create_record(input_text):
    """Build an Argilla token-classification record for ``input_text``.

    Runs the module-level ``nlp`` pipeline on the text, converts its
    predictions into ``(entity_group, start, end, score)`` tuples, splits the
    text into words using the pipeline's tokenizer, and wraps everything in
    an ``rg.TokenClassificationRecord``.

    The record is printed before being returned.
    """
    # One tuple per aggregated entity returned by the pipeline.
    entity_spans = []
    for entity in nlp(input_text, aggregation_strategy="first"):
        entity_spans.append(
            (
                entity["entity_group"],
                entity["start"],
                entity["end"],
                entity["score"],
            )
        )

    # Recover word-level tokens: collect the distinct word ids (ignoring
    # the None entries), then slice each word's character span back out
    # of the original text.
    encoding = nlp.tokenizer(input_text)
    distinct_word_ids = {wid for wid in encoding.word_ids() if wid is not None}
    spans = (encoding.word_to_chars(wid) for wid in sorted(distinct_word_ids))
    tokens = [input_text[span.start:span.end] for span in spans]

    result = rg.TokenClassificationRecord(
        text=input_text,
        tokens=tokens,
        prediction=entity_spans,
        prediction_agent="deprem-ml/deprem-ner",
    )
    print(result)
    return result
|
|
|
class ArgillaLogger(FlaggingCallback):
    """Gradio flagging callback that logs flagged inputs to an Argilla dataset."""

    def __init__(self, api_url, api_key, dataset_name):
        """Initialise the Argilla client and remember the target dataset.

        Args:
            api_url: URL passed to ``rg.init``.
            api_key: API key passed to ``rg.init``.
            dataset_name: Name of the dataset that flagged records are logged to.
        """
        rg.init(api_url=api_url, api_key=api_key)
        self.dataset_name = dataset_name
        # Number of samples this instance has logged; returned by ``flag``.
        self._flag_count = 0

    def setup(self, components: List[IOComponent], flagging_dir: str):
        """Do nothing; this callback performs no per-interface setup."""
        pass

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """Log the flagged input text to Argilla.

        Only the first element of ``flag_data`` (the input text) is used; the
        record is rebuilt from it with ``create_record``.

        Args:
            flag_data: Flagged values; element 0 is the input text.
            flag_option: Option chosen by the user. Currently not recorded.
            flag_index: Unused.
            username: Unused.

        Returns:
            The number of samples flagged through this instance so far.
        """
        # NOTE(review): flag_option is not stored on the record, so all
        # flagging choices produce identical entries — confirm whether the
        # chosen option should be logged.
        text = flag_data[0]
        rg.log(name=self.dataset_name, records=create_record(text))

        # Honour the declared ``-> int`` contract instead of returning None.
        self._flag_count += 1
        return self._flag_count
|
|
|
|
|
|
|
# Callback that sends manually flagged samples to the "ner-flags" dataset;
# the API key is read from the TEAM_API_KEY environment variable.
argilla_flagger = ArgillaLogger(
    api_url="https://merve-argilla.hf.space",
    api_key=os.getenv("TEAM_API_KEY"),
    dataset_name="ner-flags",
)

# Load the interface for the hosted model with manual flagging enabled,
# then start it.
demo = gr.Interface.load(
    "models/deprem-ml/deprem-ner",
    examples=examples,
    allow_flagging="manual",
    flagging_callback=argilla_flagger,
    flagging_options=["Correct", "Incorrect", "Ambiguous"],
)
demo.launch()