ositamiles committed on
Commit
be1fe7c
·
verified ·
1 Parent(s): c037455

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -7
app.py CHANGED
@@ -78,30 +78,41 @@ def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):
78
  return LayerNormalization(epsilon=1e-6)(x + res)
79
 
80
  # RL Environment
81
class PricingEnv(gym.Env):
    """Legacy-Gym environment for learning a pricing policy from historical data.

    Each step the agent proposes a price in [0, 100]; the reward is the
    negative absolute error against the row's recorded 'price' column, so
    the optimum policy reproduces the historical price exactly.
    """

    def __init__(self, data):
        super(PricingEnv, self).__init__()
        # data: pandas DataFrame with the feature columns read in
        # _get_observation() plus a 'price' column — TODO confirm schema
        # against the caller.
        self.data = data
        self.current_step = 0
        # Continuous scalar price action; 6-dim observation
        # (5 encoded features + the step index).
        self.action_space = spaces.Box(low=0, high=100, shape=(1,), dtype=np.float32)
        self.observation_space = spaces.Box(low=0, high=np.inf, shape=(6,), dtype=np.float32)

    def step(self, action):
        """Apply one pricing action. Returns (obs, reward, done, info) — old Gym 4-tuple API."""
        reward = self._get_reward(action)
        self.current_step += 1
        done = self.current_step >= len(self.data)
        obs = self._get_observation()
        return obs, reward, done, {}

    def reset(self):
        """Rewind to the first row and return the initial observation."""
        self.current_step = 0
        return self._get_observation()

    def _get_observation(self):
        # BUGFIX: step() increments current_step before building the terminal
        # observation, so on the last step current_step == len(data) and the
        # unclamped iloc lookup raised IndexError. Clamp to the final row.
        step = min(self.current_step, len(self.data) - 1)
        obs = self.data.iloc[step][['demand_index', 'competitor_price', 'past_sales', 'genre_encoded', 'region_encoded']].values
        return np.append(obs, step)

    def _get_reward(self, action):
        # Negative absolute deviation from this row's historical price.
        # Called before current_step is incremented, so the index is valid.
        price = action[0]
        actual_price = self.data.iloc[self.current_step]['price']
        return -abs(price - actual_price)
 
78
  return LayerNormalization(epsilon=1e-6)(x + res)
79
 
80
  # RL Environment
81
class PricingEnv(Env):
    """Gymnasium-style environment for imitating historical pricing decisions.

    The observation is five encoded market features plus the current row
    index; the action is a single price in [0, 100]; the reward penalises
    the absolute gap to the row's recorded 'price' value.
    """

    def __init__(self, data):
        super(PricingEnv, self).__init__()
        # Historical rows driving the episode — presumably a pandas
        # DataFrame; verify the expected columns against the caller.
        self.data = data
        self.current_step = 0
        self.max_steps = len(data) - 1  # index of the last usable row
        self.action_space = spaces.Box(low=0, high=100, shape=(1,), dtype=np.float32)
        self.observation_space = spaces.Box(low=0, high=np.inf, shape=(6,), dtype=np.float32)

    def step(self, action):
        """Advance one row; returns the gymnasium 5-tuple (obs, reward, terminated, truncated, info)."""
        reward = self._get_reward(action)
        self.current_step += 1
        terminated = self.current_step >= self.max_steps
        # Truncation is never signalled separately here — the episode only
        # ends by exhausting the data, reported via `terminated`.
        return self._get_observation(), reward, terminated, False, {}

    def reset(self, seed=None, options=None):
        """Re-seed via the base class and rewind to the first row; returns (obs, info)."""
        super().reset(seed=seed)
        self.current_step = 0
        return self._get_observation(), {}

    def _get_observation(self):
        # Never index past the final row: after the episode ends we keep
        # reporting the last valid observation instead of raising.
        row = min(self.current_step, self.max_steps)
        features = self.data.iloc[row][['demand_index', 'competitor_price', 'past_sales', 'genre_encoded', 'region_encoded']].values
        return np.append(features, row)

    def _get_reward(self, action):
        # Out of bounds -> neutral reward rather than an indexing error.
        if self.current_step > self.max_steps:
            return 0
        return -abs(action[0] - self.data.iloc[self.current_step]['price'])