CallmeKaito commited on
Commit
5540aa6
·
verified ·
1 Parent(s): b50a31e

Update tool.py

Browse files
Files changed (1) hide show
  1. tool.py +12 -16
tool.py CHANGED
@@ -1,19 +1,5 @@
1
  from smolagents import Tool
2
- from typing import Optional, Literal
3
-
4
- # Define valid diet options as a Literal type
5
- DietType = Literal[
6
- "none",
7
- "vegetarian",
8
- "vegan",
9
- "gluten free",
10
- "ketogenic",
11
- "paleo",
12
- "pescetarian",
13
- "whole30",
14
- "halal",
15
- "low fodmap"
16
- ]
17
 
18
  class SimpleTool(Tool):
19
  name = "CulinAI -- your AI Recipe Assistant"
@@ -48,7 +34,7 @@ class SimpleTool(Tool):
48
  }
49
  output_type = "string"
50
 
51
- def forward(self, ingredients: str, diet: DietType = "none", laziness: Optional[int] = 5) -> str:
52
  """
53
  Gets a recipe suggestion based on provided ingredients, dietary preference,
54
  and your laziness level (1=active chef, 10=super lazy). After finding a recipe, it
@@ -62,6 +48,16 @@ class SimpleTool(Tool):
62
  Returns:
63
  A string with detailed information about the recommended recipe.
64
  """
 
 
 
 
 
 
 
 
 
 
65
  import os
66
  import requests
67
 
 
1
  from smolagents import Tool
2
+ from typing import Optional
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  class SimpleTool(Tool):
5
  name = "CulinAI -- your AI Recipe Assistant"
 
34
  }
35
  output_type = "string"
36
 
37
+ def forward(self, ingredients: str, diet: str = "none", laziness: Optional[int] = 5) -> str:
38
  """
39
  Gets a recipe suggestion based on provided ingredients, dietary preference,
40
  and your laziness level (1=active chef, 10=super lazy). After finding a recipe, it
 
48
  Returns:
49
  A string with detailed information about the recommended recipe.
50
  """
51
+
52
+
53
+ # Validate diet option
54
+ valid_diets = {
55
+ "none", "vegetarian", "vegan", "gluten free", "ketogenic",
56
+ "paleo", "pescetarian", "whole30", "halal", "low fodmap"
57
+ }
58
+ if diet not in valid_diets:
59
+ return f"Invalid diet option. Please choose from: {', '.join(sorted(valid_diets))}"
60
+
61
  import os
62
  import requests
63