# Provenance (copied from the Hugging Face file view): app.py by hussain2010,
# commit 8398caf ("Update app.py"), 5.64 kB.
import os
import gradio as gr
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestRegressor
import joblib # To save and load the trained model
from groq import Groq
from huggingface_hub import login
from dotenv import load_dotenv
# Read variables from a local .env file (if present) before looking up secrets.
load_dotenv()

# A Hugging Face token is mandatory: refuse to start without one.
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if not HUGGINGFACE_TOKEN:
    raise ValueError("Hugging Face token not found in environment variables")
login(token=HUGGINGFACE_TOKEN)

# Groq client used by test_groq_api(); os.getenv yields None when the key is unset.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
client = Groq(api_key=GROQ_API_KEY)
def train_model(df, model_filename='/content/crop_yield_model.pkl'):
    """Fit a RandomForestRegressor on *df* and save it to disk with joblib.

    Args:
        df: Training data containing a numeric ``Yield`` target column and a
            ``Crop`` column holding either the labels 'Wheat'/'Rice'
            (encoded here as 1/2) or values that are already numeric codes.
            The caller's DataFrame is not modified.
        model_filename: Destination path for the fitted model. The default is
            the path load_model() reads by default.
            NOTE(review): '/content' looks like a Colab path - confirm it is
            writable in the deployment environment.

    Returns:
        The path the model was written to.
    """
    # Work on a copy so the caller's DataFrame is left untouched.
    data = df.copy()

    # Encode only text labels. Mapping already-numeric codes through the
    # label dict would turn every Crop value into NaN (the caller may have
    # encoded the column before handing it over).
    if not pd.api.types.is_numeric_dtype(data["Crop"]):
        data["Crop"] = data["Crop"].map({'Wheat': 1, 'Rice': 2})

    # Use the same feature columns, in the same order, that predict_yield sends.
    # Otherwise extra CSV columns (or a different column order) would produce a
    # model that rejects prediction input.
    feature_columns = [
        "Nitrogen", "Phosphorus", "Potassium", "Temperature",
        "Humidity", "pH_Value", "Rainfall", "Crop",
    ]
    if all(col in data.columns for col in feature_columns):
        X = data[feature_columns]
    else:
        X = data.drop(columns=["Yield"])
    y = data["Yield"]  # Target variable

    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X, y)

    # Make sure the target directory exists before dumping the model.
    directory = os.path.dirname(model_filename)
    if directory:
        os.makedirs(directory, exist_ok=True)
    joblib.dump(model, model_filename)
    return model_filename
def load_model(model_filename='/content/crop_yield_model.pkl'):
    """Load a model previously saved with joblib (e.g. by train_model).

    Args:
        model_filename: Path of the saved model. The default matches the path
            train_model() writes to by default.

    Returns:
        The deserialized model object.

    Raises:
        FileNotFoundError: if no model file exists at *model_filename*. This is
            a subclass of Exception, so existing ``except Exception`` handlers
            still catch it.
    """
    # NOTE(review): joblib.load unpickles the file - only load model files
    # this app produced itself, never user-supplied ones.
    if not os.path.exists(model_filename):
        raise FileNotFoundError(
            "Model not found. Please upload a valid dataset and train the model."
        )
    return joblib.load(model_filename)
def _encode_crop(crop):
    """Convert a crop name or code into the numeric code used for training.

    Accepts 'Wheat'/'Rice' (case-insensitive, surrounding whitespace ignored)
    or a numeric code. The code may be a number or text such as "1", which is
    what the Gradio Textbox delivers. Non-string values are returned as-is.

    Raises:
        ValueError: if *crop* is text that is neither a known name nor a number.
    """
    if not isinstance(crop, str):
        return crop
    text = crop.strip()
    codes = {"wheat": 1, "rice": 2}
    key = text.lower()
    if key in codes:
        return codes[key]
    try:
        return float(text)
    except ValueError:
        raise ValueError(
            f"Unknown crop {crop!r}: enter 'Wheat', 'Rice', 1 (Wheat) or 2 (Rice)."
        ) from None


def predict_yield(N, P, K, temperature, humidity, pH, rainfall, crop, model=None):
    """Predict crop yield for a single set of soil and weather readings.

    Args:
        N, P, K: Nitrogen, phosphorus and potassium values.
        temperature, humidity, pH, rainfall: Remaining feature values, assumed
            to use the same units as the training CSV (TODO confirm).
        crop: 'Wheat'/'Rice' (any case) or the code 1/2, given as a number or
            as text.
        model: Optional already-loaded regressor. When omitted, the model is
            loaded from disk with load_model().

    Returns:
        A display string of the form "Predicted Yield: <value> kg/ha".

    Raises:
        ValueError: if *crop* is text that is neither a crop name nor a number.
    """
    input_data = {
        "Nitrogen": N,
        "Phosphorus": P,
        "Potassium": K,
        "Temperature": temperature,
        "Humidity": humidity,
        "pH_Value": pH,
        "Rainfall": rainfall,
        "Crop": _encode_crop(crop),
    }
    if model is None:
        model = load_model()

    # Prepare input data as a single-row DataFrame.
    input_df = pd.DataFrame([input_data])

    # scikit-learn rejects input whose column order differs from the order
    # seen in fit(); align to the model's recorded order when it exposes one.
    feature_order = getattr(model, "feature_names_in_", None)
    if feature_order is not None and set(feature_order) <= set(input_df.columns):
        input_df = input_df[list(feature_order)]

    yield_prediction = model.predict(input_df)
    return f"Predicted Yield: {yield_prediction[0]:.2f} kg/ha"
def upload_file(file):
    """Validate an uploaded crop-yield CSV and train the model on it.

    Args:
        file: The value delivered by ``gr.File``. Depending on the Gradio
            version, this is either a filesystem path (str) or a temp-file
            wrapper exposing ``.name``. It is ``None`` when nothing was uploaded.

    Returns:
        A status string for the UI. Validation problems and any exception
        raised while reading or training are reported as "Error: ..." strings
        rather than raised.
    """
    if file is None:
        return "Error: No file uploaded. Please upload a CSV file."
    # Newer Gradio passes a path string; older versions pass a tempfile wrapper.
    path = getattr(file, "name", file)

    try:
        df = pd.read_csv(path)

        # Check that the required columns are present.
        required_columns = ["Nitrogen", "Phosphorus", "Potassium", "Temperature", "Humidity", "pH_Value", "Rainfall", "Crop", "Yield"]
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            return f"Error: Missing columns: {', '.join(missing_columns)}. Please upload a file with the required columns."

        # Keep only the expected columns, in the order predict_yield builds its
        # input row. Extra CSV columns would otherwise become model features
        # that prediction never supplies.
        df = df[required_columns].copy()

        # Check for missing values in the required columns.
        if df.isnull().values.any():
            return "Error: Dataset contains missing values. Please clean the dataset before uploading."

        # Every column except 'Crop' must be numeric.
        numeric_columns = [col for col in required_columns if col != "Crop"]
        for col in numeric_columns:
            if not pd.api.types.is_numeric_dtype(df[col]):
                return f"Error: Column '{col}' contains non-numeric values. Please ensure all feature columns are numeric."

        # Normalize Crop to the canonical labels 'Wheat'/'Rice', accepting any
        # case and the numeric codes 1/2. Unknown crops are rejected instead of
        # silently becoming NaN. train_model performs the numeric encoding, so
        # it happens exactly once.
        crop_labels = {
            "wheat": "Wheat", "rice": "Rice",
            "1": "Wheat", "2": "Rice",
            "1.0": "Wheat", "2.0": "Rice",
        }
        normalized = df["Crop"].astype(str).str.strip().str.lower().map(crop_labels)
        unknown = sorted(df.loc[normalized.isna(), "Crop"].astype(str).unique())
        if unknown:
            return (
                f"Error: Unrecognized Crop value(s): {', '.join(unknown)}. "
                "Expected 'Wheat' or 'Rice' (or the codes 1 / 2)."
            )
        df["Crop"] = normalized

        # If everything is fine, train the model.
        model_filename = train_model(df)
        return f"Model trained successfully and saved as {model_filename}. You can now make predictions."
    except Exception as e:  # UI boundary: surface any failure as a status message
        return f"Error: {str(e)}"
def interactive_interface(N, P, K, temperature, humidity, pH, rainfall, crop):
    """Gradio callback for the prediction form.

    Forwards the eight form values, unchanged and in order, to predict_yield
    and returns its result string.
    """
    form_values = (N, P, K, temperature, humidity, pH, rainfall, crop)
    return predict_yield(*form_values)
def test_groq_api():
    """Smoke-test the Groq chat API with one fixed prompt.

    Uses the module-level ``client`` and returns the text content of the
    first choice in the completion.
    """
    prompt = {
        "role": "user",
        "content": "Explain the importance of fast language models",
    }
    completion = client.chat.completions.create(
        messages=[prompt],
        model="llama3-8b-8192",
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
# Prediction UI: seven numeric inputs followed by a crop text field.
interface = gr.Interface(
    fn=interactive_interface,
    inputs=[
        gr.Number(label=label)
        for label in (
            "Nitrogen (N)",
            "Phosphorus (P)",
            "Potassium (K)",
            "Temperature (°C)",
            "Humidity (%)",
            "pH Value",
            "Rainfall (mm)",
        )
    ]
    + [gr.Textbox(label="Crop (1 = Wheat, 2 = Rice)")],
    outputs=[gr.Textbox(label="Predicted Yield")],
    title="Optimized Crop Yield Prediction",
    description="Input soil and weather parameters to predict crop yield.",
)

# Training UI: a single CSV upload whose outcome is shown as a status message.
upload_interface = gr.Interface(
    fn=upload_file,
    inputs=gr.File(label="Upload your crop yield dataset (CSV)"),
    outputs=[gr.Textbox(label="Status")],
    title="Upload and Train Model",
    description="Upload a CSV file with crop yield data to train the prediction model.",
)
# Entry point: smoke-test the Groq API, then serve the training and prediction UIs.
if __name__ == "__main__":
    # A Groq failure (missing key, network error) should not stop the app
    # from starting; report it and continue.
    try:
        print("Groq API Test Response:", test_groq_api())
    except Exception as exc:  # top-level boundary
        print("Groq API Test failed:", exc)

    # launch() blocks the main thread when run as a script, so calling it on
    # each interface in turn would never start the second one. Serve both
    # interfaces together as tabs of a single app instead.
    gr.TabbedInterface(
        [upload_interface, interface],
        ["Upload and Train Model", "Predict Crop Yield"],
    ).launch()