# Hugging Face Spaces status: Build error
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from transformers import pipeline
# Candidate risk levels passed to the classifier; the same list is used as
# the heatmap's x-axis tick labels.
labels = ["high risk", "medium risk", "low risk"]

# Zero-shot NLI classifier, built once at module import and reused by every
# call to the request handler below.
classifier = pipeline(task="zero-shot-classification", model="facebook/bart-large-mnli")
def _scores_in_label_order(result, label_order):
    """Return the scores from a zero-shot *result* rearranged to follow *label_order*.

    The zero-shot pipeline reports ``result["labels"]`` and ``result["scores"]``
    sorted by descending score, not in the order the candidate labels were
    passed in. The scores are therefore re-keyed by label name so that each
    one lands in the heatmap column whose tick label actually names it.
    """
    score_for = dict(zip(result["labels"], result["scores"]))
    return [score_for[label] for label in label_order]


def _render_heatmap(scores_array, row_names, col_names):
    """Draw *scores_array* as an annotated heatmap and save it to a new PNG.

    Args:
        scores_array: 2-D array with one row per clause and one column per
            risk label.
        row_names: Y-axis tick label for each row.
        col_names: X-axis tick label for each column.

    Returns:
        Path of a newly created temporary PNG file (unique per call).
    """
    # Grow the figure with the number of clauses so row labels stay legible.
    height = max(6.0, 0.5 * len(row_names))
    fig, ax = plt.subplots(figsize=(10, height))
    try:
        sns.heatmap(
            scores_array,
            annot=True,
            xticklabels=col_names,
            yticklabels=row_names,
            cmap="Reds",
            # Scores are probabilities; pin the colour scale to [0, 1] so a
            # given shade means the same thing from one request to the next.
            vmin=0.0,
            vmax=1.0,
            ax=ax,
        )
        ax.set_title("Contract Clause Risk Heatmap")
        ax.set_xlabel("Risk Level")
        ax.set_ylabel("Clauses")
        fig.tight_layout()
        # One file per call: a shared "heatmap.png" could be overwritten by a
        # concurrent request before Gradio reads it back.
        with tempfile.NamedTemporaryFile(
            prefix="heatmap_", suffix=".png", delete=False
        ) as tmp:
            fig.savefig(tmp, format="png")
        return tmp.name
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)


def classify_clauses(text_input):
    """Score each contract clause against the risk labels and plot a heatmap.

    Args:
        text_input: Free text with one clause per line; blank lines and
            surrounding whitespace are ignored. ``None`` is treated as empty.

    Returns:
        Filesystem path of a PNG heatmap with one row per clause and one
        column per entry of the module-level ``labels`` list.

    Raises:
        gr.Error: If the input has no non-blank lines. This shows a readable
            message in the UI instead of an opaque plotting error.
    """
    clauses = [
        line.strip() for line in (text_input or "").splitlines() if line.strip()
    ]
    if not clauses:
        raise gr.Error("Please enter at least one contract clause (one per line).")

    # Rows follow clause order; columns follow `labels`, not the pipeline's
    # score-sorted output order.
    scores_array = np.array(
        [_scores_in_label_order(classifier(clause, labels), labels) for clause in clauses]
    )
    row_names = [f"Clause {i}" for i in range(1, len(clauses) + 1)]
    return _render_heatmap(scores_array, row_names, labels)
# Gradio UI: a multi-line textbox goes in, the rendered heatmap image comes out.
clause_box = gr.Textbox(lines=10, label="Enter Contract Clauses (one per line)")
heatmap_image = gr.Image(type="filepath")

demo = gr.Interface(
    fn=classify_clauses,
    inputs=clause_box,
    outputs=heatmap_image,
    title="Contract Risk Heatmap Generator",
    description="Enter clauses line by line. Uses zero-shot classification to visualize risk levels.",
)


if __name__ == "__main__":
    # Start the server only when executed as a script, not when imported.
    demo.launch()