# Hugging Face Spaces page header ("Spaces: Running") - scrape residue, not program code.
import os | |
import pandas as pd | |
import json | |
from datetime import datetime | |
from fastapi import FastAPI, HTTPException | |
from fastapi.middleware.cors import CORSMiddleware | |
from pydantic import BaseModel | |
import uvicorn | |
import requests | |
# Application object; the __main__ guard at the bottom of the file serves it
# through uvicorn as "app:app".
app = FastAPI(title="HTS to HSN Classifier", version="1.0.0")

# CORS is left fully open: every origin, method and header is accepted and
# credentialed requests are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive - confirm this is intended for the deployment.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
class ClassificationRequest(BaseModel):
    """Input model accepted by ``classify_hts``."""

    # Either an all-digit HTS code or a free-text description; the value is
    # handed unchanged to ``map_hts_to_hsn``.
    hts_code_or_desc: str
class ClassificationResponse(BaseModel):
    """Shape of a classification result (same keys ``map_hts_to_hsn`` returns)."""

    # Matched HSN code, or None when nothing could be matched.
    HSN_Code: str | None
    # Description belonging to the matched code, or None when there is no match.
    HSN_Description: str | None
    # Confidence label; the mapping code emits "High", "Medium" or "Low".
    Confidence: str
    # Human-readable trail of how the result was reached.
    Reasoning: str
class HuggingFaceInferenceClient: | |
def __init__(self, model_name: str = "meta-llama/Meta-Llama-3-8B-Instruct", api_token: str = None): | |
self.model_name = model_name | |
self.api_token = api_token or os.getenv("HF_API_TOKEN") | |
if not self.api_token: | |
raise ValueError("Hugging Face API token not provided") | |
self.headers = { | |
"Authorization": f"Bearer {self.api_token}", | |
"Content-Type": "application/json" | |
} | |
# Fixed API URL - use the correct inference endpoint | |
self.api_url = f"https://api-inference.huggingface.co/models/{self.model_name}" | |
def invoke(self, prompt: str) -> str: | |
# Fixed payload structure for Hugging Face Inference API | |
payload = { | |
"inputs": prompt, | |
"parameters": { | |
"max_new_tokens": 500, | |
"temperature": 0.6, | |
"top_p": 0.95, | |
"return_full_text": False | |
} | |
} | |
try: | |
response = requests.post(self.api_url, json=payload, headers=self.headers, timeout=60) | |
response.raise_for_status() | |
data = response.json() | |
# Handle different response formats | |
if isinstance(data, list) and len(data) > 0: | |
if "generated_text" in data[0]: | |
return data[0]["generated_text"] | |
elif "text" in data[0]: | |
return data[0]["text"] | |
elif isinstance(data, dict): | |
if "generated_text" in data: | |
return data["generated_text"] | |
elif "text" in data: | |
return data["text"] | |
return str(data) | |
except requests.exceptions.RequestException as e: | |
print(f"API request failed: {e}") | |
raise Exception(f"Hugging Face API error: {e}") | |
except json.JSONDecodeError as e: | |
print(f"JSON decode error: {e}") | |
raise Exception(f"Invalid JSON response from API: {e}") | |
async def startup_event():
    """Load the HTS and HSN spreadsheets and create the LLM client.

    Assigns the module-level globals ``df_hts``, ``df_hsn``, ``llm_client`` and
    the four column-name globals that ``map_hts_to_hsn`` reads. ``vs_hts`` and
    ``vs_hsn`` are declared global but never assigned here.

    NOTE(review): no startup registration (decorator or event handler) is
    visible for this coroutine in this view - confirm it is wired to ``app``.

    Raises:
        Exception: Any failure while reading the files or building the client
            is printed and then re-raised unchanged.
    """
    global vs_hts, vs_hsn, df_hts, df_hsn, llm_client, hts_code_col, hts_desc_col, hsn_code_col, hsn_desc_col

    hts_path = "data/Htsdata.xlsx"
    hsn_path = "data/HSN_SAC.xlsx"

    try:
        print("Loading HSN data from:", hsn_path)

        hts_code_col = "HTS Number"
        hts_desc_col = "Description"
        hsn_code_col = "HSN_CD"
        hsn_desc_col = "HSN_Description"

        # Header cells are stripped so the column names above match even when
        # the spreadsheet headers carry stray whitespace.
        df_hts = pd.read_excel(hts_path)
        df_hts.columns = df_hts.columns.str.strip()

        df_hsn = pd.read_excel(hsn_path)
        df_hsn.columns = df_hsn.columns.str.strip()
        # HSN codes are compared as text (equality and prefix tests) later on.
        df_hsn[hsn_code_col] = df_hsn[hsn_code_col].astype(str)

        llm_client = HuggingFaceInferenceClient(model_name="meta-llama/Meta-Llama-3-8B-Instruct")
        print("✅ Application started successfully!")
    except Exception as e:
        print(f"❌ Startup error: {e}")
        # Bare raise keeps the original traceback intact.
        raise
async def health_check():
    """Report that the service is up, stamped with the current local time.

    Returns:
        A dict with ``status`` set to ``"healthy"`` and ``timestamp`` holding
        ``datetime.now()`` in ISO-8601 form.
    """
    now_iso = datetime.now().isoformat()
    return dict(status="healthy", timestamp=now_iso)
def extract_structure(code: str):
    """Split a tariff code into the digit prefixes used for lookups.

    Every non-digit character is discarded first, so ``"8471.30.01"`` is
    treated as ``"84713001"``.

    Args:
        code: The code to split; it is converted with ``str()`` first.

    Returns:
        A dict with four keys:
        ``chapter`` - first 4 digits, or None with fewer than 4 digits;
        ``heading`` - first 6 digits, or None with fewer than 6 digits;
        ``hsn8`` - first 8 digits, or all digits when fewer than 8 exist;
        ``full`` - every digit of the input.
    """
    digits = "".join(ch for ch in str(code) if ch.isdigit())

    def prefix(width):
        # A prefix exists only when the code has at least `width` digits.
        if len(digits) < width:
            return None
        return digits[:width]

    first_eight = prefix(8)
    return {
        "chapter": prefix(4),
        "heading": prefix(6),
        "hsn8": digits if first_eight is None else first_eight,
        "full": digits,
    }
def extract_json_from_text(text: str) -> dict: | |
"""Extract JSON from text response, handling various formats.""" | |
text = text.strip() | |
# Find JSON content between braces | |
start_idx = text.find('{') | |
end_idx = text.rfind('}') | |
if start_idx != -1 and end_idx != -1 and start_idx < end_idx: | |
json_str = text[start_idx:end_idx + 1] | |
try: | |
return json.loads(json_str) | |
except json.JSONDecodeError: | |
pass | |
# If JSON extraction fails, try to parse key-value pairs | |
try: | |
lines = text.split('\n') | |
result = {} | |
for line in lines: | |
if ':' in line: | |
key, value = line.split(':', 1) | |
key = key.strip().strip('"\'') | |
value = value.strip().strip('",\'') | |
if key in ['HSN_Code', 'HSN_Description', 'Confidence', 'Reasoning']: | |
result[key] = value | |
if len(result) >= 2: # At least some keys found | |
return result | |
except: | |
pass | |
return None | |
def _hsn_result(code, description, confidence, reasoning):
    """Build the result dict shared by every classification outcome."""
    return {
        "HSN_Code": code,
        "HSN_Description": description,
        "Confidence": confidence,
        "Reasoning": reasoning,
    }


def _rows_with_prefix(frame, column, prefix):
    """Return rows of *frame* whose *column* text starts with *prefix* (no rows when prefix is None)."""
    if prefix is None:
        return frame.iloc[0:0]
    return frame[frame[column].astype(str).str.startswith(prefix)]


def _build_hsn_prompt(full_code, hts_desc_text, heading_code, heading_desc):
    """Format the Llama-3 chat prompt asking for an 8-digit HSN code as JSON."""
    system_prompt = "You are an expert in Indian HSN classification. Respond only with valid JSON containing the keys: HSN_Code, HSN_Description, Confidence, Reasoning."
    user_prompt = f"""
Input HTS code: {full_code}
HTS Description: {hts_desc_text}
Fallback HSN heading: {heading_code} - {heading_desc}
Based on this information, provide the most appropriate 8-digit HSN code and description.
Required JSON format:
{{
"HSN_Code": "XXXXXXXX",
"HSN_Description": "description here",
"Confidence": "High/Medium/Low",
"Reasoning": "explanation here"
}}
"""
    return f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"


def _ask_llm_for_hsn(prompt):
    """Query the LLM; return a result dict, or None when the call or the parse fails."""
    try:
        llm_response = llm_client.invoke(prompt).strip()
        print(f"LLM Response: {llm_response}")  # Debug logging
        parsed_response = extract_json_from_text(llm_response)
        if parsed_response and all(key in parsed_response for key in ["HSN_Code", "HSN_Description"]):
            # Confidence and Reasoning are optional in the reply; default them.
            return _hsn_result(
                parsed_response.get("HSN_Code"),
                parsed_response.get("HSN_Description"),
                parsed_response.get("Confidence", "Medium"),
                parsed_response.get("Reasoning", "LLM classification"),
            )
        print(f"Invalid LLM response format: {parsed_response}")
    except Exception as e:
        # Best effort: any LLM problem falls back to the heading match.
        print(f"LLM failed: {e}")
    return None


def map_hts_to_hsn(hts_code_or_desc: str):
    """Map an all-digit HTS code to an HSN code and description.

    Relies on the module globals assigned by ``startup_event`` (``df_hts``,
    ``df_hsn``, ``llm_client`` and the four column-name globals).

    Lookup order for digit-only input:

    1. HSN row whose code equals the first 8 digits (all digits when the
       input is shorter) - confidence "High".
    2. First HSN row whose code starts with the first 6 digits; the LLM is
       asked to refine it, and the row itself is returned with confidence
       "Medium" when the LLM call or its parsing fails.
    3. First HSN row whose code starts with the first 4 digits - "Low".
    4. No match - code and description are None, confidence "Low".

    Inputs with fewer than 6 (or 4) digits skip the prefix lookups they
    cannot support instead of raising.

    Args:
        hts_code_or_desc: HTS code or description. Anything that is not made
            up solely of digits is treated as a description, which this
            deployment does not support.

    Returns:
        A dict with the keys ``HSN_Code``, ``HSN_Description``,
        ``Confidence`` and ``Reasoning``.
    """
    if not hts_code_or_desc.isdigit():
        return _hsn_result(None, None, "Low", "Description search not available in this deployment.")

    struct = extract_structure(hts_code_or_desc)
    chapter = struct["chapter"]
    heading = struct["heading"]
    hsn8 = struct["hsn8"]
    reasoning_parts = [f"Input HTS code: {struct['full']}"]

    # HTS context: up to three descriptions sharing the 4-digit prefix.
    # Missing descriptions are dropped so the join below never sees NaN.
    hts_match = _rows_with_prefix(df_hts, hts_code_col, chapter)
    hts_desc_list = hts_match[hts_desc_col].dropna().astype(str).head(3).tolist()
    hts_desc_text = "; ".join(hts_desc_list) if hts_desc_list else "No HTS description found."
    if chapter is None:
        reasoning_parts.append("Input has fewer than 4 digits; HTS prefix lookup skipped.")
    else:
        reasoning_parts.append(f"HTS Chapter {chapter}: {hts_desc_text}")

    hsn_match = df_hsn[df_hsn[hsn_code_col] == hsn8]
    if not hsn_match.empty:
        best_match = hsn_match.iloc[0]
        reasoning_parts.append(f"Exact 8-digit HSN {hsn8} found.")
        return _hsn_result(
            best_match[hsn_code_col],
            best_match[hsn_desc_col],
            "High",
            " ".join(reasoning_parts),
        )

    fallback_heading_match = _rows_with_prefix(df_hsn, hsn_code_col, heading)
    if not fallback_heading_match.empty:
        fallback_heading = fallback_heading_match.iloc[0]
        reasoning_parts.append(f"No exact 8-digit HSN. Fallback heading {heading} found.")
        full_prompt = _build_hsn_prompt(
            struct["full"],
            hts_desc_text,
            fallback_heading[hsn_code_col],
            fallback_heading[hsn_desc_col],
        )
        llm_result = _ask_llm_for_hsn(full_prompt)
        if llm_result is not None:
            return llm_result
        return _hsn_result(
            fallback_heading[hsn_code_col],
            fallback_heading[hsn_desc_col],
            "Medium",
            " ".join(reasoning_parts) + " LLM failed, using fallback 6-digit heading.",
        )

    chapter_match = _rows_with_prefix(df_hsn, hsn_code_col, chapter)
    if not chapter_match.empty:
        best_match = chapter_match.iloc[0]
        reasoning_parts.append(f"No heading match. Fallback to chapter {chapter}.")
        return _hsn_result(
            best_match[hsn_code_col],
            best_match[hsn_desc_col],
            "Low",
            " ".join(reasoning_parts),
        )

    return _hsn_result(None, None, "Low", " ".join(reasoning_parts) + " No HSN match found.")
async def classify_hts(request: ClassificationRequest):
    """Classify the HTS code or description carried by *request*.

    NOTE(review): no route decorator is visible for this handler in this
    view - confirm it is registered on ``app``.

    Args:
        request: Payload whose ``hts_code_or_desc`` is passed to
            ``map_hts_to_hsn``.

    Returns:
        The result dict produced by ``map_hts_to_hsn``.

    Raises:
        HTTPException: Status 500 carrying the error text when the mapping
            raises any exception.
    """
    try:
        return map_hts_to_hsn(request.hts_code_or_desc)
    except Exception as exc:
        detail = f"Classification error: {str(exc)}"
        raise HTTPException(status_code=500, detail=detail)
async def root():
    """Return a static banner naming the service and its status."""
    banner = dict(message="HTS to HSN Classification API", status="running")
    return banner
if __name__ == "__main__":
    # Direct execution: serve the "app" object of module "app" on all
    # interfaces, port 7860, with auto-reload switched on.
    # NOTE(review): reload=True is normally a development setting - confirm it
    # is wanted in the deployed Space.
    server_options = {"host": "0.0.0.0", "port": 7860, "reload": True}
    uvicorn.run("app:app", **server_options)