Spaces:
Sleeping
Sleeping
| from fastapi import APIRouter, Depends, HTTPException | |
| from jinja2 import Environment | |
| from litellm.router import Router | |
| from dependencies import get_llm_router, get_prompt_templates | |
| from schemas import _ReqGroupingCategory, _ReqGroupingOutput, ReqGroupingCategory, ReqGroupingRequest, ReqGroupingResponse, ReqSearchLLMResponse, ReqSearchRequest, ReqSearchResponse | |
| # Router for requirement processing; the tag below applies to any route registered on it | |
| router = APIRouter(tags=["requirement processing"]) | |
def find_requirements_from_problem_description(req: ReqSearchRequest, llm_router: Router = Depends(get_llm_router)):
    """Find the requirements that address a given problem description.

    The extracted requirements are rendered as one text line each, sent to
    the LLM together with the problem description, and the LLM answers with
    the list of 'Selection ID' values it considers relevant.

    Args:
        req: Request carrying the extracted ``requirements`` and the problem
            description ``query``.
        llm_router: LiteLLM router used to run the completion.

    Returns:
        A ``ReqSearchResponse`` holding the selected requirements, in the
        order returned by the LLM. It is empty when the LLM selects nothing.

    Raises:
        HTTPException: status 500 when the LLM returns an ID that is not a
            valid position in ``req.requirements``.
    """
    requirements = req.requirements
    query = req.query

    requirements_text = "\n".join(
        f"[Selection ID: {r.req_id} | Document: {r.document} | Context: {r.context} | Requirement: {r.requirement}]"
        for r in requirements)

    resp_ai = llm_router.completion(
        model="gemini-v2",
        messages=[{"role": "user", "content": f"Given all the requirements : \n {requirements_text} \n and the problem description \"{query}\", return a list of 'Selection ID' for the most relevant corresponding requirements that reference or best cover the problem. If none of the requirements covers the problem, simply return an empty list"}],
        response_format=ReqSearchLLMResponse
    )

    raw_content = resp_ai.choices[0].message.content
    print("Answered")
    print(raw_content)

    selected = ReqSearchLLMResponse.model_validate_json(raw_content).selected

    # NOTE(review): the returned IDs are used as positions in `requirements`,
    # which presumes each `req_id` equals the requirement's index - confirm
    # against the caller.
    # Every ID must be a valid, non-negative position. An empty selection is a
    # legitimate answer (the prompt asks for it when nothing matches), so it
    # must not be treated as an error; negative values would otherwise index
    # from the end of the list.
    n_requirements = len(requirements)
    if any(not 0 <= idx < n_requirements for idx in selected):
        raise HTTPException(
            status_code=500, detail="LLM error : Generated a wrong index, please try again.")

    return ReqSearchResponse(requirements=[requirements[idx] for idx in selected])
async def categorize_reqs(params: ReqGroupingRequest, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> ReqGroupingResponse:
    """Categorize the given service requirements into categories.

    The LLM is asked to group the requirements by their indices. If its
    answer leaves some requirement out of every category, the conversation
    is continued with a reminder listing the missing indices, up to
    ``MAX_ATTEMPTS`` completions in total.

    Args:
        params: Request carrying the ``requirements`` to group and the
            ``max_n_categories`` limit passed to the prompt template.
        prompt_env: Jinja2 environment providing the ``classify.txt`` template.
        llm_router: LiteLLM router used to run the completions.

    Returns:
        A ``ReqGroupingResponse`` whose categories reference the actual
        requirement objects; indices that do not match any requirement are
        dropped.

    Raises:
        Exception: when the LLM still has not assigned every requirement
            after ``MAX_ATTEMPTS`` attempts.
    """
    MAX_ATTEMPTS = 5

    # The set of acceptable indices is fixed for the whole call.
    valid_ids_universe = set(range(len(params.requirements)))

    # Ask the LLM to categorize the requirements using their indices.
    req_prompt = await prompt_env.get_template("classify.txt").render_async(
        requirements=[rq.model_dump() for rq in params.requirements],
        max_n_categories=params.max_n_categories,
        response_schema=_ReqGroupingOutput.model_json_schema(),
    )

    # The rendered prompt (instructions + requirements) is sent as a user message.
    messages = [{"role": "user", "content": req_prompt}]

    # Retry until every requirement index appears in at least one category.
    for _ in range(MAX_ATTEMPTS):
        req_completion = await llm_router.acompletion(model="gemini-v2", messages=messages, response_format=_ReqGroupingOutput)
        reply = req_completion.choices[0].message
        output = _ReqGroupingOutput.model_validate_json(reply.content)

        assigned_ids = {
            req_id for cat in output.categories for req_id in cat.items}

        # Hallucinated IDs outside the universe cannot reduce this difference,
        # so only genuinely missing requirements remain.
        unassigned_ids = valid_ids_universe - assigned_ids
        if not unassigned_ids:
            break

        # Feed the LLM its own answer plus a reminder of what is missing.
        messages.append(reply)
        messages.append(
            {"role": "user", "content": f"You haven't categorized the following requirements in at least one category {unassigned_ids}. Please do so."})
    else:
        # Loop exhausted without a complete categorization.
        raise Exception("Failed to classify all requirements")

    # Build the final category objects, dropping invalid (likely hallucinated)
    # requirement IDs. Membership in the valid set also rejects negative IDs,
    # which would otherwise silently index from the end of the list.
    final_categories = [
        ReqGroupingCategory(
            id=idx,
            title=cat.title,
            requirements=[params.requirements[i]
                          for i in cat.items if i in valid_ids_universe],
        )
        for idx, cat in enumerate(output.categories)
    ]

    return ReqGroupingResponse(categories=final_categories)