# patent_gen / app.py
# ghus75
# update prompts: plastic bottle strap
# 34c208b
#import os
#import gradio as gr
#HF_Hub_API_token = os.environ.get('HF_Hub_API_token', None)
#import gradio as gr
#demo = gr.load("gaelhuser/patent_gen_prv", hf_token=HF_Hub_API_token, src="spaces")
#demo.launch()
import gradio as gr
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from langchain_core.messages import HumanMessage, SystemMessage
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain.memory import ConversationBufferWindowMemory
import os
# Credentials are read from environment variables; each is None when unset.
HF_Hub_API_token = os.getenv('HF_Hub_API_token')
github_token = os.getenv('github_token')

# Text-generation endpoint for the Mistral-7B-Instruct model, with sampling
# switched off (do_sample=False).
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.2",
    task="text-generation",
    max_new_tokens=2048,
    do_sample=False,
    repetition_penalty=1.03,
    huggingfacehub_api_token=HF_Hub_API_token,
)

# One memory and one chain are created at module level, so every callback in
# this process talks to the same conversation history.
memory = ConversationBufferMemory()
conversation = ConversationChain(memory=memory, llm=llm, verbose=False)
# Fetch the remote prompt module: the repository is cloned only when the
# local 'repo_directory' path does not exist yet.
from git import Repo

if not os.path.exists('repo_directory'):
    Repo.clone_from(
        f'https://ghus75:{github_token}@github.com/ghus75/patent_gen.git',
        'repo_directory',
    )

# NOTE(review): presumably provides the promptNx fragments and the example_*
# texts used further down (they are not defined in this file) - confirm.
from repo_directory.gen_prompts import *

# Module-level draft: a single list for the whole process, i.e. shared
# between all users.
patent_draft = []
# When True, generated sections go through trim_output() before being stored.
trim = False
# Share of each paragraph's words that trim_output() keeps.
fraction = 0.15
def trim_output(response, keep_fraction=None):
    """Return a shortened preview of ``response``.

    The first line is kept verbatim. Every following line that contains more
    than one space-separated word is cut down to its leading
    ``int(keep_fraction * n_words)`` words and suffixed with a marker telling
    how many words were left out. Following lines made of a single word, or
    of no text at all, are dropped.

    Args:
        response: Text to shorten; lines are separated by ``'\\n'``.
        keep_fraction: Share of each following line's words to keep. ``None``
            (the default) falls back to the module-level ``fraction`` setting.

    Returns:
        The shortened text. When ``response`` has more than one line, the
        first line is always followed by ``'\\n'``, even if every later line
        was dropped.
    """
    ratio = fraction if keep_fraction is None else keep_fraction

    # Split once and reuse the pieces instead of re-splitting on each access.
    first_line, *later_lines = response.split('\n')
    if not later_lines:
        return first_line

    shortened = []
    for line in later_lines:
        words = line.split(' ')
        if len(words) <= 1:
            # Blank or single-word lines are not carried into the preview.
            continue
        kept = int(ratio * len(words))
        left_out = len(words) - kept
        shortened.append(' '.join(words[:kept]) + f" (...) [ {left_out} words left ]")
    return first_line + '\n' + '\n'.join(shortened)
def gen_intro(invention_title, progress=gr.Progress()):
    """Generate the introduction section once and return the whole draft.

    The model is queried only when no introduction is present in the shared
    draft yet and a non-empty title was supplied. In every case the current
    draft is returned, so the output box is never blanked by an implicit
    ``None``.

    Args:
        invention_title: Title typed by the user. ``None`` is treated like an
            empty string (the fields are reset to ``None`` by ``clear()``).
        progress: Gradio progress tracker wrapped around the generation step.

    Returns:
        All draft sections joined with newlines.
    """
    title = invention_title or ''
    # Only generate if this button has not already produced its section.
    already_added = any('** Introduction: **' in section for section in patent_draft)
    if title and not already_added:
        prompt1 = prompt1a + title + prompt1b
        # Single-step loop: the call is routed through the progress tracker.
        for _ in progress.tqdm(range(1)):
            response = conversation.predict(input=prompt1)
            if trim:
                response = trim_output(response)
            patent_draft.append("** Introduction: **\n" + response)
    # Otherwise simply return what is already displayed.
    return '\n'.join(patent_draft)
def gen_bckgnd(current_sota, disadvantage, current_objective, progress=gr.Progress()):
    """Generate the background section once and return the whole draft.

    The model is queried only when the shared draft has no background section
    yet and at least one of the three fields is non-empty. The current draft
    is returned in every case, so the output box is never blanked by an
    implicit ``None``.

    Args:
        current_sota: Description of the current state of the art.
        disadvantage: Disadvantage of the current state of the art.
        current_objective: Objective of the invention.
        progress: Gradio progress tracker wrapped around the generation step.

    Returns:
        All draft sections joined with newlines.

    ``None`` field values are treated like empty strings (the fields are
    reset to ``None`` by ``clear()``).
    """
    sota = current_sota or ''
    drawback = disadvantage or ''
    objective = current_objective or ''

    already_added = any('** Background: **' in section for section in patent_draft)
    if not already_added and (sota or drawback or objective):
        prompt2 = prompt2a + sota + prompt2b + drawback + prompt2c + objective + prompt2_end
        # Single-step loop: the call is routed through the progress tracker.
        for _ in progress.tqdm(range(1)):
            response = conversation.predict(input=prompt2)
            if trim:
                response = trim_output(response)
            patent_draft.append("\n ** Background: **\n" + response)
    return '\n'.join(patent_draft)
def gen_claim1(claim1, tech_adv, progress=gr.Progress()):
    """Generate the technical description once and return the whole draft.

    The model is queried only when the shared draft has no technical
    description yet and at least one of the two fields is non-empty. The
    current draft is returned in every case, so the output box is never
    blanked by an implicit ``None``.

    Args:
        claim1: Text of the principal claim.
        tech_adv: Technical advantage associated with the principal claim.
        progress: Gradio progress tracker wrapped around the generation step.

    Returns:
        All draft sections joined with newlines.

    ``None`` field values are treated like empty strings (the fields are
    reset to ``None`` by ``clear()``).
    """
    claim_text = claim1 or ''
    advantage = tech_adv or ''

    already_added = any('** Technical description: **' in section for section in patent_draft)
    if not already_added and (claim_text or advantage):
        prompt3 = prompt3a + claim_text + prompt3b + advantage + prompt3c
        # Single-step loop: the call is routed through the progress tracker.
        for _ in progress.tqdm(range(1)):
            response = conversation.predict(input=prompt3)
            if trim:
                response = trim_output(response)
            patent_draft.append("\n ** Technical description: **\n" + response)
    return '\n'.join(patent_draft)
def gen_dept_claims(dependent_claim, tech_effects, progress=gr.Progress()):
    """Generate the dependent-claims description once and return the draft.

    The model is queried only when the shared draft has no dependent-claims
    section yet and the dependent claims field is non-empty; the technical
    effects field alone does not trigger a generation. The current draft is
    returned in every case.

    Args:
        dependent_claim: Text of the dependent claims.
        tech_effects: Technical effects appended to the prompt; may be empty.
        progress: Gradio progress tracker wrapped around the generation step.

    Returns:
        All draft sections joined with newlines.

    ``None`` field values are treated like empty strings (the fields are
    reset to ``None`` by ``clear()``).
    """
    claims = dependent_claim or ''
    effects = tech_effects or ''

    already_added = any(
        '** Technical description for dependent claims: **' in section
        for section in patent_draft
    )
    if claims and not already_added:
        prompt4 = prompt4a + claims + prompt4b + effects
        # Single-step loop: the call is routed through the progress tracker.
        for _ in progress.tqdm(range(1)):
            response = conversation.predict(input=prompt4)
            if trim:
                response = trim_output(response)
            patent_draft.append("\n ** Technical description for dependent claims: **\n" + response)
    return '\n'.join(patent_draft)
# Output box showing the accumulated draft. It is instantiated here, ahead of
# the layout, and placed on the page later through out.render(); the button
# callbacks write their return value into it.
out = gr.Textbox(label="Patent draft:", show_label=True)
def clear():
    """Forget the conversation, empty the shared draft and blank the UI.

    Returns:
        A list of nine ``None`` values, one for each component wired as an
        output of the "Reset draft" button.
    """
    global patent_draft
    conversation.memory.clear()
    patent_draft = []
    return [None] * 9
# Page layout: header text, one tab per draft section, the shared output box
# and a reset button. User-facing text is kept as written (partly French).
with gr.Blocks() as demo:
    # Title and short description. The closing tags are emitted in proper
    # nesting order (</h2> before </center>).
    gr.Markdown(
        """<center><h2>Générateur de description d'invention</h2></center>
<br>
Cette application est un exemple basique de génération de description d'invention à partir d'informations sommaires
fournies par l'utilisateur.<br>
Ces informations sont combinées avec des instructions pré-définies puis passées au modèle de langage Mistral-7B pour générer une description d'invention.<br>""")
    # Usage notice and confidentiality warning.
    gr.Markdown(
        """
<h3>Notice d'utilisation:</h3>
- Pour chaque onglet "Introduction", "Background"... renseigner les champs (au besoin cliquer sur les exemples proposés en dessous de chaque zone de texte)
<br>- Valider chaque étape avec les boutons "Add introduction", "Add background", etc. A chaque validation, le champ "Patent draft" est mis à jour.
<br>- Le bouton "Reset draft" efface toutes les données saisies et remet la conversation à zero.<br>
<br><br> <b>Attention !</b> Cette application ne sert qu'à illustrer un exemple d'application d'un modèle de langage pour l'aide à la rédaction de brevet.
<br>Le modèle utilisé est hébergé sur le site huggingface.co : aucune confidentialité n'est garantie !
""")

    # One tab per section; each button sends its fields to the matching
    # generator and writes the returned draft into the shared output box.
    with gr.Row():
        with gr.Tab("Introduction"):
            invention_title = gr.Textbox(label="Enter invention title", show_label=True)
            examples = gr.Examples(examples=[example_title], inputs=[invention_title])
            btn = gr.Button("Add introduction")
            btn.click(fn=gen_intro, inputs=[invention_title], outputs=out)

        with gr.Tab("Background"):
            current_sota = gr.Textbox(label="The current state of the art is the following:", show_label=True)
            examples = gr.Examples(examples=[example_SOTA], inputs=[current_sota])
            current_disadvantage = gr.Textbox(label="The disadvantage of the current state of the art is the following:", show_label=True)
            examples = gr.Examples(examples=[example_disadvantage], inputs=[current_disadvantage])
            current_objective = gr.Textbox(label="The Objective is the following::", show_label=True)
            examples = gr.Examples(examples=[example_objective], inputs=[current_objective])
            btn = gr.Button("Add invention background")
            btn.click(fn=gen_bckgnd, inputs=[current_sota, current_disadvantage, current_objective], outputs=out)

        with gr.Tab("Principal claim"):
            principal_claim = gr.Textbox(label="The principal claim is the following:", show_label=True)
            examples = gr.Examples(examples=[example_claim1], inputs=[principal_claim])
            tech_adv = gr.Textbox(label="The technical advantage is the following:", show_label=True)
            examples = gr.Examples(examples=[example_tech_adv], inputs=[tech_adv])
            btn = gr.Button("Add description from principal claim")
            btn.click(fn=gen_claim1, inputs=[principal_claim, tech_adv], outputs=out)

        with gr.Tab("Dependent claim"):
            dependent_claim = gr.Textbox(label="The dependent claims are the following:", show_label=True)
            examples = gr.Examples(examples=[example_dep_claim], inputs=[dependent_claim])
            tech_effects = gr.Textbox(label="Technical effects accepted by EPO:", show_label=True)
            examples = gr.Examples(examples=[example_technical_effects], inputs=[tech_effects])
            btn = gr.Button("Add description from dependent claims")
            btn.click(fn=gen_dept_claims, inputs=[dependent_claim, tech_effects], outputs=out)

    # The output box was created before the layout; place it here.
    with gr.Row():
        out.render()

    # Reset button: clear() returns one value per component listed in outputs.
    with gr.Row():
        #gr.ClearButton(components=[out])
        clear_btn = gr.Button(value="Reset draft")
        clear_btn.click(
            clear,
            inputs=[],
            outputs=[out, invention_title, current_sota, current_disadvantage, current_objective,
                     principal_claim, tech_adv, dependent_claim, tech_effects],
        )

# Enable request queuing, then start the app.
demo.queue().launch()