Spaces:
Runtime error
Runtime error
Commit
·
9575e16
1
Parent(s):
96424ac
update demo
Browse files
texts/getting_my_agent_evaluated.md
CHANGED
@@ -93,7 +93,7 @@ class Agent(nn.Module):
|
|
93 |
agent = Agent(policy) # instantiate the agent
|
94 |
|
95 |
# A few tests to check if the agent is working
|
96 |
-
observations = torch.
|
97 |
actions = agent(observations)
|
98 |
actions = actions.numpy()[0]
|
99 |
assert env.action_space.contains(actions)
|
@@ -109,10 +109,9 @@ from huggingface_hub import metadata_save, HfApi
|
|
109 |
|
110 |
# Save model along with its card
|
111 |
metadata_save("model_card.md", {"tags": ["reinforcement-learning", env_id]})
|
112 |
-
dummy_input = torch.
|
113 |
agent = torch.jit.trace(agent.eval(), dummy_input)
|
114 |
-
agent = torch.jit.freeze(agent) # required for
|
115 |
-
agent = torch.jit.optimize_for_inference(agent)
|
116 |
torch.jit.save(agent, "agent.pt")
|
117 |
|
118 |
# Upload model and card to the 🤗 Hub
|
|
|
93 |
agent = Agent(policy) # instantiate the agent
|
94 |
|
95 |
# A few tests to check if the agent is working
|
96 |
+
observations = torch.randn(env.observation_space.shape).unsqueeze(0) # dummy batch of observations
|
97 |
actions = agent(observations)
|
98 |
actions = actions.numpy()[0]
|
99 |
assert env.action_space.contains(actions)
|
|
|
109 |
|
110 |
# Save model along with its card
|
111 |
metadata_save("model_card.md", {"tags": ["reinforcement-learning", env_id]})
|
112 |
+
dummy_input = torch.randn(env.observation_space.shape).unsqueeze(0) # dummy batch of observations
|
113 |
agent = torch.jit.trace(agent.eval(), dummy_input)
|
114 |
+
agent = torch.jit.freeze(agent) # required for the model not to depend on the training library
|
|
|
115 |
torch.jit.save(agent, "agent.pt")
|
116 |
|
117 |
# Upload model and card to the 🤗 Hub
|