from typing import Optional, Type, TypedDict

from pydantic import BaseModel, Field
from langchain_core.messages import HumanMessage
from langgraph.graph import StateGraph, START, END

import constants


# Define the State structure (similar to previous definition)
class State(TypedDict):
    messages: list
    output: Optional[BaseModel]

# Generic Pydantic model-based structured output extractor
class StructuredOutputExtractor:
    def __init__(self, response_schema: Type[BaseModel]):
        """
        Initializes the extractor for any given structured output model.

        :param response_schema: Pydantic model class used for structured output extraction
        """
        self.response_schema = response_schema
        # Initialize language model (provider and API keys come from constants.py)
        self.llm = self._choose_llm_provider(constants.CHOSEN_LLM_PROVIDER)
        # Bind the model with structured output capability
        self.structured_llm = self.llm.with_structured_output(response_schema)
        # Build the graph for structured output
        self._build_graph()

    def _build_graph(self):
        """
        Build the LangGraph computational graph for structured extraction.
        """
        graph_builder = StateGraph(State)
        # Add nodes and edges for structured output
        graph_builder.add_node("extract", self._extract_structured_info)
        graph_builder.add_edge(START, "extract")
        graph_builder.add_edge("extract", END)
        self.graph = graph_builder.compile()

    def _extract_structured_info(self, state: dict):
        """
        Extract structured information using the specified response model.

        :param state: Current graph state
        :return: Updated state with structured output
        """
        query = state['messages'][-1].content
        print(f"Processing query: {query}")
        try:
            # Extract details using the structured model
            output = self.structured_llm.invoke(query)
            # Return the structured response
            return {"output": output}
        except Exception as e:
            print(f"Error during extraction: {e}")
            return {"output": None}

    def extract(self, query: str) -> Optional[BaseModel]:
        """
        Public method to extract structured information.

        :param query: Input query for structured output extraction
        :return: Structured model object or None
        """
        result = self.graph.invoke({
            "messages": [HumanMessage(content=query)]
        })
        # Return the structured model response, if available
        return result.get('output')

    def _choose_llm_provider(self, chosen_llm_provider):
        """Dynamically imports and returns the chat model for the configured provider; raises ImportError if that provider's integration package is not installed."""
        api_key = constants.llm_api_keys.get(chosen_llm_provider)
        if chosen_llm_provider == 'openai':
            from langchain_openai import ChatOpenAI
            return ChatOpenAI(model=constants.selected_llm_model.get('openai'), streaming=True, api_key=api_key)
        elif chosen_llm_provider == 'ollama':
            from langchain_ollama import ChatOllama
            return ChatOllama(model=constants.selected_llm_model.get('ollama'))  # streaming is enabled by default
        elif chosen_llm_provider == 'groq':
            from langchain_groq import ChatGroq
            return ChatGroq(model=constants.selected_llm_model.get('groq'), streaming=True, api_key=api_key)
        elif chosen_llm_provider == 'anthropic':
            from langchain_anthropic import ChatAnthropic
            return ChatAnthropic(model=constants.selected_llm_model.get('anthropic'), streaming=True, api_key=api_key)
        else:
            raise ValueError(f"Unsupported LLM provider: {chosen_llm_provider}")
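
# A minimal sketch of the constants.py this module expects. CHOSEN_LLM_PROVIDER,
# llm_api_keys, and selected_llm_model are the names referenced above; the chosen
# provider, model names, and key formats below are illustrative placeholders only.
#
#   CHOSEN_LLM_PROVIDER = 'ollama'
#   llm_api_keys = {
#       'openai': 'sk-...',
#       'groq': 'gsk_...',
#       'anthropic': 'sk-ant-...',
#   }
#   selected_llm_model = {
#       'openai': 'gpt-4o-mini',
#       'ollama': 'llama3.1',
#       'groq': 'llama-3.1-8b-instant',
#       'anthropic': 'claude-3-5-sonnet-latest',
#   }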

if __name__ == '__main__':
    # Example Pydantic model (e.g., Movie)
    class Movie(BaseModel):
        title: str = Field(description="the title of the youtube video")
        title_image: str = Field(description="highly detailed and descriptive image prompt for the title")
        items: list[str] = Field(description="top n requested items")
        image_prompts: list[str] = Field(description="highly detailed and descriptive image prompts for each item")

    # Example usage with a generic structured extractor
    extractor = StructuredOutputExtractor(response_schema=Movie)
    query = "Top 5 Superheroes"
    result = extractor.extract(query)
    print(type(result))
    if result:
        print(result)
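
    # A second, purely illustrative schema (assumption: any Pydantic BaseModel
    # subclass works the same way, since the extractor only depends on the
    # response_schema passed to its constructor). Recipe and its fields are
    # hypothetical names, not part of the original module.
    class Recipe(BaseModel):
        name: str = Field(description="name of the dish")
        ingredients: list[str] = Field(description="ingredients required for the dish")
        steps: list[str] = Field(description="ordered preparation steps")

    recipe_extractor = StructuredOutputExtractor(response_schema=Recipe)
    recipe = recipe_extractor.extract("A simple recipe for pancakes")
    if recipe:
        print(recipe)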