Spaces:
Running
Running
Commit
·
d8c637e
1
Parent(s):
1ccdaa7
Refactor GenerateRequest model; update field names to snake_case, add aliases, and improve validation methods
Browse files
main.py
CHANGED
@@ -42,9 +42,9 @@ class EmotionalTone(str, Enum):
|
|
42 |
ROMANTIC = "romantic"
|
43 |
|
44 |
class Length(str, Enum):
|
45 |
-
SHORT = "short"
|
46 |
-
MEDIUM = "medium"
|
47 |
-
LONG = "long"
|
48 |
|
49 |
@dataclass
|
50 |
class StyleConfig:
|
@@ -110,30 +110,34 @@ class StyleMapper:
|
|
110 |
class GenerateRequest(BaseModel):
|
111 |
prompt: str
|
112 |
style: PoemStyle
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
length: Length
|
117 |
-
|
118 |
|
119 |
-
@validator('
|
120 |
def validate_creative_style(cls, v):
|
121 |
if not 0 <= v <= 100:
|
122 |
raise ValueError('creativeStyle must be between 0 and 100')
|
123 |
return v
|
124 |
|
125 |
-
@validator('
|
126 |
def validate_language_variety(cls, v):
|
127 |
if not 0 <= v <= 1:
|
128 |
raise ValueError('languageVariety must be between 0 and 1')
|
129 |
return v
|
130 |
|
131 |
-
@validator('
|
132 |
def validate_word_repetition(cls, v):
|
133 |
if not 1 <= v <= 2:
|
134 |
raise ValueError('wordRepetition must be between 1 and 2')
|
135 |
return v
|
136 |
|
|
|
|
|
|
|
|
|
137 |
class ModelManager:
|
138 |
def __init__(self):
|
139 |
self.model = None
|
|
|
42 |
ROMANTIC = "romantic"
|
43 |
|
44 |
class Length(str, Enum):
|
45 |
+
SHORT = "short"
|
46 |
+
MEDIUM = "medium"
|
47 |
+
LONG = "long"
|
48 |
|
49 |
@dataclass
|
50 |
class StyleConfig:
|
|
|
110 |
class GenerateRequest(BaseModel):
|
111 |
prompt: str
|
112 |
style: PoemStyle
|
113 |
+
emotional_tone: EmotionalTone = Field(alias="emotionalTone")
|
114 |
+
creative_style: float = Field(ge=0, le=100, alias="creativeStyle") # 0-100 slider
|
115 |
+
language_variety: float = Field(ge=0, le=1, alias="languageVariety") # 0-1 slider
|
116 |
length: Length
|
117 |
+
word_repetition: float = Field(ge=1, le=2, alias="wordRepetition") # 1-2 slider
|
118 |
|
119 |
+
@validator('creative_style')
|
120 |
def validate_creative_style(cls, v):
|
121 |
if not 0 <= v <= 100:
|
122 |
raise ValueError('creativeStyle must be between 0 and 100')
|
123 |
return v
|
124 |
|
125 |
+
@validator('language_variety')
|
126 |
def validate_language_variety(cls, v):
|
127 |
if not 0 <= v <= 1:
|
128 |
raise ValueError('languageVariety must be between 0 and 1')
|
129 |
return v
|
130 |
|
131 |
+
@validator('word_repetition')
|
132 |
def validate_word_repetition(cls, v):
|
133 |
if not 1 <= v <= 2:
|
134 |
raise ValueError('wordRepetition must be between 1 and 2')
|
135 |
return v
|
136 |
|
137 |
+
class Config:
|
138 |
+
allow_population_by_field_name = True
|
139 |
+
alias_generator = None
|
140 |
+
|
141 |
class ModelManager:
|
142 |
def __init__(self):
|
143 |
self.model = None
|