Spaces:
Runtime error
Runtime error
import os | |
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
from sklearn.ensemble import RandomForestRegressor | |
import joblib # To save and load the trained model | |
from groq import Groq | |
from huggingface_hub import login | |
from dotenv import load_dotenv | |
# Pull settings from a local .env file into the process environment.
load_dotenv()

# Hugging Face authentication: the token is mandatory, so stop immediately
# when it is absent instead of continuing unauthenticated.
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if not HUGGINGFACE_TOKEN:
    raise ValueError("Hugging Face token not found in environment variables")
login(token=HUGGINGFACE_TOKEN)

# Groq client built from the API key found in the environment.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
client = Groq(api_key=GROQ_API_KEY)
# Function to train the RandomForestRegressor model and save it to disk
def train_model(df, model_filename='/content/crop_yield_model.pkl'):
    """Fit a random-forest regressor on ``df`` and persist it with joblib.

    Args:
        df: DataFrame holding a ``Yield`` target column, a ``Crop`` column and
            the feature columns. ``Crop`` may contain the names ``'Wheat'`` /
            ``'Rice'`` or already be numerically encoded.
        model_filename: Destination of the pickled model. Defaults to the
            path this function has always written to.

    Returns:
        The path the model was written to.

    Raises:
        ValueError: If ``Crop`` holds missing values or names other than
            ``'Wheat'`` / ``'Rice'``.
    """
    # Work on a copy so the caller's DataFrame is never modified.
    data = df.copy()

    # Encode crop names only when they are still text. Re-applying the
    # name -> code mapping to an already encoded column would turn every
    # value into NaN.
    if not pd.api.types.is_numeric_dtype(data['Crop']):
        data['Crop'] = data['Crop'].map({'Wheat': 1, 'Rice': 2})
    if data['Crop'].isnull().any():
        raise ValueError(
            "Column 'Crop' contains unsupported or missing values; "
            "only 'Wheat' and 'Rice' are supported."
        )

    X = data.drop(columns=["Yield"])  # Features
    y = data["Yield"]  # Target variable

    # Fixed random_state keeps training reproducible for identical data.
    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X, y)

    # Make sure the destination directory exists before writing the model.
    target_dir = os.path.dirname(model_filename)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    joblib.dump(model, model_filename)
    return model_filename
# Load the trained model
def load_model():
    """Return the model stored at the fixed pickle path.

    Raises:
        Exception: If no model file exists at that path yet.
    """
    saved_path = '/content/crop_yield_model.pkl'
    if not os.path.exists(saved_path):
        raise Exception("Model not found. Please upload a valid dataset and train the model.")
    return joblib.load(saved_path)
# Function to predict crop yield based on input features
def _encode_crop(crop):
    """Translate a crop name or code into the numeric code, or ``None`` if invalid."""
    if isinstance(crop, str):
        crop = crop.strip().lower()
        names = {"wheat": 1, "rice": 2}
        if crop in names:
            return names[crop]
    try:
        value = float(crop)
    except (TypeError, ValueError):
        return None
    return int(value) if value in (1.0, 2.0) else None


def predict_yield(N, P, K, temperature, humidity, pH, rainfall, crop):
    """Predict the crop yield for one set of soil and weather measurements.

    Args:
        N, P, K: Nitrogen, phosphorus and potassium values.
        temperature, humidity, pH, rainfall: Remaining numeric features.
        crop: Crop given as a code (``1``/``2``, number or text) or as a name
            (``Wheat``/``Rice``, case-insensitive).

    Returns:
        ``"Predicted Yield: <value> kg/ha"`` on success, or a string starting
        with ``"Error:"`` when an input is missing or the crop is not
        recognised.

    Raises:
        Exception: Propagated from ``load_model`` when no trained model exists.
    """
    numeric_inputs = (N, P, K, temperature, humidity, pH, rainfall)
    # An empty numeric field arrives as None; reject it before predicting.
    if any(value is None for value in numeric_inputs):
        return "Error: All numeric inputs are required."

    # The crop is entered as free text in the UI, while the model is trained
    # on numeric crop codes, so normalise it first.
    crop_code = _encode_crop(crop)
    if crop_code is None:
        return "Error: Crop must be 1 (Wheat) or 2 (Rice)."

    input_data = {
        "Nitrogen": N,
        "Phosphorus": P,
        "Potassium": K,
        "Temperature": temperature,
        "Humidity": humidity,
        "pH_Value": pH,
        "Rainfall": rainfall,
        "Crop": crop_code,
    }
    model = load_model()  # Load the trained model

    # Prepare input data as a single-row DataFrame
    input_df = pd.DataFrame([input_data])

    # When the estimator recorded its training column names, present the
    # same columns in the same order; a different order is otherwise
    # rejected at predict time.
    trained_columns = getattr(model, "feature_names_in_", None)
    if trained_columns is not None and set(trained_columns) == set(input_df.columns):
        input_df = input_df[list(trained_columns)]

    yield_prediction = model.predict(input_df)
    return f"Predicted Yield: {yield_prediction[0]:.2f} kg/ha"
# Gradio Interface function to upload file and train the model
def upload_file(file):
    """Validate an uploaded CSV dataset and train the yield model on it.

    Args:
        file: The uploaded file, either as a filesystem path or as an object
            exposing the path through a ``name`` attribute.

    Returns:
        A status message: a success text containing the saved model path, or
        a string starting with ``"Error:"`` describing why the dataset was
        rejected or why training failed.
    """
    if file is None:
        return "Error: No file uploaded. Please upload a CSV file."

    # Accept both a plain path and a file-like wrapper carrying `.name`.
    csv_path = getattr(file, "name", file)

    # Every failure is reported as text because the return value is shown
    # directly in the status box.
    try:
        df = pd.read_csv(csv_path)

        # Check if required columns are present
        required_columns = ["Nitrogen", "Phosphorus", "Potassium", "Temperature", "Humidity", "pH_Value", "Rainfall", "Crop", "Yield"]
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            return f"Error: Missing columns: {', '.join(missing_columns)}. Please upload a file with the required columns."

        if df.empty:
            return "Error: Dataset contains no rows. Please upload a file with data."

        # Check for missing values in the required columns
        if df[required_columns].isnull().sum().sum() > 0:
            return "Error: Dataset contains missing values. Please clean the dataset before uploading."

        # Check data types of the columns (they should be numeric except for the 'Crop' column)
        for col in ["Nitrogen", "Phosphorus", "Potassium", "Temperature", "Humidity", "pH_Value", "Rainfall", "Yield"]:
            if not pd.api.types.is_numeric_dtype(df[col]):
                return f"Error: Column '{col}' contains non-numeric values. Please ensure all feature columns are numeric."

        # Reject crops the model cannot encode instead of letting them turn
        # into NaN during encoding.
        unsupported = df.loc[~df['Crop'].isin(['Wheat', 'Rice']), 'Crop'].unique()
        if len(unsupported) > 0:
            listed = ', '.join(str(value) for value in unsupported)
            return f"Error: Unsupported crop values: {listed}. Only 'Wheat' and 'Rice' are supported."

        # Crop encoding is left to train_model: encoding here as well would
        # apply the name -> code mapping twice and wipe the column.
        # Only the required columns are passed on, in a fixed order, so
        # unrelated CSV columns never become model features.
        model_filename = train_model(df[required_columns].copy())
        return f"Model trained successfully and saved as {model_filename}. You can now make predictions."
    except Exception as e:
        return f"Error: {str(e)}"
# Gradio Interface function for prediction
def interactive_interface(N, P, K, temperature, humidity, pH, rainfall, crop):
    """Forward the form values unchanged to ``predict_yield`` and return its result."""
    return predict_yield(
        N=N,
        P=P,
        K=K,
        temperature=temperature,
        humidity=humidity,
        pH=pH,
        rainfall=rainfall,
        crop=crop,
    )
# Additional Groq API test function
def test_groq_api():
    """Send one fixed user prompt through the Groq client and return the reply text."""
    prompt = {
        "role": "user",
        "content": "Explain the importance of fast language models",
    }
    completion = client.chat.completions.create(
        messages=[prompt],
        model="llama3-8b-8192",
    )
    # Only the first returned choice is used.
    first_choice = completion.choices[0]
    return first_choice.message.content
# Gradio setup for prediction interface: seven numeric fields followed by a
# free-text crop field, in the order interactive_interface expects them.
interface = gr.Interface(
    fn=interactive_interface,
    inputs=[
        *(
            gr.Number(label=caption)
            for caption in (
                "Nitrogen (N)",
                "Phosphorus (P)",
                "Potassium (K)",
                "Temperature (°C)",
                "Humidity (%)",
                "pH Value",
                "Rainfall (mm)",
            )
        ),
        gr.Textbox(label="Crop (1 = Wheat, 2 = Rice)"),
    ],
    outputs=[gr.Textbox(label="Predicted Yield")],
    title="Optimized Crop Yield Prediction",
    description="Input soil and weather parameters to predict crop yield.",
)
# Gradio setup for uploading dataset: one file input, one status text output.
upload_interface = gr.Interface(
    upload_file,
    gr.File(label="Upload your crop yield dataset (CSV)"),
    [gr.Textbox(label="Status")],
    title="Upload and Train Model",
    description="Upload a CSV file with crop yield data to train the prediction model.",
)
# Smoke-test the Groq API, then serve both interfaces from a single app
if __name__ == "__main__":
    # The Groq call is only a diagnostic; a failure (missing key, network or
    # model error) is reported but must not keep the app from starting.
    try:
        print("Groq API Test Response:", test_groq_api())
    except Exception as exc:
        print("Groq API test failed:", exc)

    # launch() blocks the main thread when run as a script, so launching the
    # two interfaces one after the other would leave the second unreachable.
    # Serving them as tabs of one app makes both available.
    gr.TabbedInterface(
        [upload_interface, interface],
        tab_names=["Upload and Train Model", "Predict Yield"],
    ).launch()