abhisheksan commited on
Commit
d8c637e
·
1 Parent(s): 1ccdaa7

Refactor GenerateRequest model; update field names to snake_case, add aliases, and improve validation methods

Browse files
Files changed (1) hide show
  1. main.py +14 -10
main.py CHANGED
@@ -42,9 +42,9 @@ class EmotionalTone(str, Enum):
42
  ROMANTIC = "romantic"
43
 
44
  class Length(str, Enum):
45
- SHORT = "short" # 100 words
46
- MEDIUM = "medium" # 200 words
47
- LONG = "long" # 300 words
48
 
49
  @dataclass
50
  class StyleConfig:
@@ -110,30 +110,34 @@ class StyleMapper:
110
  class GenerateRequest(BaseModel):
111
  prompt: str
112
  style: PoemStyle
113
- emotionalTone: EmotionalTone
114
- creativeStyle: float = Field(ge=0, le=100) # 0-100 slider
115
- languageVariety: float = Field(ge=0, le=1) # 0-1 slider
116
  length: Length
117
- wordRepetition: float = Field(ge=1, le=2) # 1-2 slider
118
 
119
- @validator('creativeStyle')
120
  def validate_creative_style(cls, v):
121
  if not 0 <= v <= 100:
122
  raise ValueError('creativeStyle must be between 0 and 100')
123
  return v
124
 
125
- @validator('languageVariety')
126
  def validate_language_variety(cls, v):
127
  if not 0 <= v <= 1:
128
  raise ValueError('languageVariety must be between 0 and 1')
129
  return v
130
 
131
- @validator('wordRepetition')
132
  def validate_word_repetition(cls, v):
133
  if not 1 <= v <= 2:
134
  raise ValueError('wordRepetition must be between 1 and 2')
135
  return v
136
 
 
 
 
 
137
  class ModelManager:
138
  def __init__(self):
139
  self.model = None
 
42
  ROMANTIC = "romantic"
43
 
44
  class Length(str, Enum):
45
+ SHORT = "short"
46
+ MEDIUM = "medium"
47
+ LONG = "long"
48
 
49
  @dataclass
50
  class StyleConfig:
 
110
  class GenerateRequest(BaseModel):
111
  prompt: str
112
  style: PoemStyle
113
+ emotional_tone: EmotionalTone = Field(alias="emotionalTone")
114
+ creative_style: float = Field(ge=0, le=100, alias="creativeStyle") # 0-100 slider
115
+ language_variety: float = Field(ge=0, le=1, alias="languageVariety") # 0-1 slider
116
  length: Length
117
+ word_repetition: float = Field(ge=1, le=2, alias="wordRepetition") # 1-2 slider
118
 
119
+ @validator('creative_style')
120
  def validate_creative_style(cls, v):
121
  if not 0 <= v <= 100:
122
  raise ValueError('creativeStyle must be between 0 and 100')
123
  return v
124
 
125
+ @validator('language_variety')
126
  def validate_language_variety(cls, v):
127
  if not 0 <= v <= 1:
128
  raise ValueError('languageVariety must be between 0 and 1')
129
  return v
130
 
131
+ @validator('word_repetition')
132
  def validate_word_repetition(cls, v):
133
  if not 1 <= v <= 2:
134
  raise ValueError('wordRepetition must be between 1 and 2')
135
  return v
136
 
137
+ class Config:
138
+ allow_population_by_field_name = True
139
+ alias_generator = None
140
+
141
  class ModelManager:
142
  def __init__(self):
143
  self.model = None