File size: 974 Bytes
93a19af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from datetime import datetime, timezone
from typing import Optional

from pydantic import BaseModel, Field
from pydantic.json_schema import SkipJsonSchema

# Default sampling parameters. "max_new_tokens" is also used as the upper
# bound for the ``max_new_tokens`` request field. "temperature" is written as
# a float so the default matches the ``float`` annotation of the field it
# seeds (pydantic does not coerce defaults).
CONFIG = {
    "max_new_tokens": 1000,
    "temperature": 1.0,
    "top_p": 0.8,
}

class UserRequest(BaseModel):
    """Request model carrying a prompt, steering settings and sampling parameters.

    Sampling defaults are taken from the module-level ``CONFIG`` and are
    bounded by the ``Field`` constraints declared below.
    """

    session_id: str
    # None is the default, so the annotation must admit None as well.
    prompt: Optional[str] = None
    steering: bool = True
    coeff: float = -1.0
    # Must be a positive count and may not exceed the configured maximum.
    max_new_tokens: int = Field(
        CONFIG["max_new_tokens"], gt=0, le=CONFIG["max_new_tokens"]
    )
    top_p: float = Field(CONFIG["top_p"], ge=0.0, le=1.0)
    temperature: float = Field(CONFIG["temperature"], ge=0.0, le=1.0)

    def generation_config(self) -> dict:
        """Return this request's sampling parameters as a plain dict.

        Returns:
            A dict with the keys ``"max_new_tokens"``, ``"top_p"`` and
            ``"temperature"`` set to the corresponding field values.
        """
        return {
            "max_new_tokens": self.max_new_tokens,
            "top_p": self.top_p,
            "temperature": self.temperature,
        }


class SteeringOutput(UserRequest):
    """A ``UserRequest`` extended with the generated output and its metadata.

    ``max_new_tokens`` is hidden from the JSON schema and excluded from
    serialization; all other request fields are inherited unchanged.
    """

    # Re-declare the parent's default and upper bound. A bare
    # ``Field(exclude=True)`` would make the field required and drop the
    # ``le`` constraint; the override is only meant to change schema
    # visibility and serialization.
    max_new_tokens: SkipJsonSchema[int] = Field(
        CONFIG["max_new_tokens"], le=CONFIG["max_new_tokens"], exclude=True
    )
    # None is the default, so the annotations must admit None as well.
    output: Optional[str] = None
    upvote: Optional[bool] = None
    # ISO-8601 string of the UTC time at which the instance is created.
    timestamp: str = Field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )