Integrating the falconLLM
Files changed:
- app.py: +18 -23
- tester.py: +6 -42
- trainer.py: +11 -10
app.py
CHANGED
@@ -5,6 +5,8 @@ import streamlit as st
 import os
 from trainer import train
 from tester import test
+import transformers
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 def main():
@@ -24,32 +26,25 @@ def main():
     st.sidebar.write(f"Jammer Type: {jammer_type}")
     st.sidebar.write(f"Channel Switching Cost: {channel_switching_cost}")
 
-
-    test_button = st.sidebar.button('Test')
-
-    if train_button or test_button:
-        agent_name = f'DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
-        if os.path.exists(agent_name):
-            if train_button:
-                st.warning("Agent has been trained already! Do you want to retrain?")
-                retrain = st.sidebar.button('Yes')
-                if retrain:
-                    perform_training(jammer_type, channel_switching_cost)
-            elif test_button:
-                perform_testing(jammer_type, channel_switching_cost)
-        else:
-            if train_button:
-                perform_training(jammer_type, channel_switching_cost)
-            elif test_button:
-                st.warning("Agent has not been trained yet. Click Train First!!!")
+    start_button = st.sidebar.button('Start')
+
+    if start_button:
+        agent = perform_training(jammer_type, channel_switching_cost)
+        test(agent, jammer_type, channel_switching_cost)
 
 
 def perform_training(jammer_type, channel_switching_cost):
-    train(jammer_type, channel_switching_cost)
-
-
-
-
+    agent = train(jammer_type, channel_switching_cost)
+    model_name = "tiiuae/falcon-7b-instruct"
+    model = AutoModelForCausalLM.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    pipeline = transformers.pipeline("text-generation", model=model, tokenizer=tokenizer, max_length=100, temperature=0.7)
+    text = pipeline("Discuss this topic: Integrating LLMs to DRL-based anti-jamming.")
+    st.write(text)
+    return agent
+
+
+def perform_testing(agent, jammer_type, channel_switching_cost):
+    test(agent, jammer_type, channel_switching_cost)
 
 
 if __name__ == "__main__":
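Note on the pipeline usage: as committed, perform_training() re-downloads and re-instantiates tiiuae/falcon-7b-instruct on every click of 'Start', and st.write(text) renders the pipeline's raw return value, which for a text-generation pipeline is a list of dicts keyed by "generated_text". Below is a minimal sketch of a cached loader, assuming Streamlit >= 1.18 for st.cache_resource; the helper name load_falcon_pipeline is hypothetical and not part of the commit.

    import streamlit as st
    import transformers
    from transformers import AutoModelForCausalLM, AutoTokenizer


    @st.cache_resource  # load the 7B model once and reuse it across Streamlit reruns
    def load_falcon_pipeline(model_name="tiiuae/falcon-7b-instruct"):
        # Same construction as the committed code, wrapped so it runs only once
        model = AutoModelForCausalLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        return transformers.pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=100,
            temperature=0.7,
        )


    pipeline = load_falcon_pipeline()
    text = pipeline("Discuss this topic: Integrating LLMs to DRL-based anti-jamming.")
    st.write(text[0]["generated_text"])  # unwrap the list-of-dicts return value

Called from perform_training(), this would replace the four model/tokenizer/pipeline lines added above without changing behavior beyond the caching.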
tester.py
CHANGED
@@ -9,26 +9,17 @@ from DDQN import DoubleDeepQNetwork
 from antiJamEnv import AntiJamEnv
 
 
-def test(jammer_type, channel_switching_cost):
+def test(agent, jammer_type, channel_switching_cost):
     env = AntiJamEnv(jammer_type, channel_switching_cost)
     ob_space = env.observation_space
     ac_space = env.action_space
 
     s_size = ob_space.shape[0]
     a_size = ac_space.n
-    max_env_steps =
-    TEST_Episodes =
+    max_env_steps = 3
+    TEST_Episodes = 1
     env._max_episode_steps = max_env_steps
-
-    epsilon = 1.0  # exploration rate
-    epsilon_min = 0.01
-    epsilon_decay = 0.999
-    discount_rate = 0.95
-    lr = 0.001
-
-    agentName = f'./data/DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
-    DDQN_agent = DoubleDeepQNetwork(s_size, a_size, lr, discount_rate, epsilon, epsilon_min, epsilon_decay)
-    DDQN_agent.model = DDQN_agent.load_saved_model(agentName)
+    DDQN_agent = agent
     rewards = []  # Store rewards for graphing
     epsilons = []  # Store the Explore/Exploit
 
@@ -47,35 +38,8 @@ def test(jammer_type, channel_switching_cost):
             break
         next_state = np.reshape(next_state, [1, s_size])
         tot_rewards += reward
+
+        st.write(f"The state is: {state}, action taken is: {action}, obtained reward is: {reward}")
         # DON'T STORE ANYTHING DURING TESTING
         state = next_state
 
-    # Plotting
-    rolling_average = np.convolve(rewards, np.ones(10) / 10, mode='valid')
-
-    # Create a new Streamlit figure
-    fig = plt.figure()
-    plt.plot(rewards, label='Rewards')
-    plt.plot(rolling_average, color='black', label='Rolling Average')
-    plt.axhline(y=max_env_steps - 0.10 * max_env_steps, color='r', linestyle='-', label='Solved Line')
-    eps_graph = [100 * x for x in epsilons]
-    plt.plot(eps_graph, color='g', linestyle='-', label='Epsilons')
-    plt.xlabel('Episodes')
-    plt.ylabel('Rewards')
-    plt.title(f'Testing Rewards - {jammer_type}, CSC: {channel_switching_cost}')
-    plt.legend()
-
-    # Display the Streamlit figure using streamlit.pyplot
-    st.set_option('deprecation.showPyplotGlobalUse', False)
-    st.pyplot(fig)
-
-    # Save the figure
-    plot_name = f'./data/test_rewards_{jammer_type}_csc_{channel_switching_cost}.png'
-    plt.savefig(plot_name, bbox_inches='tight')
-    plt.close(fig)  # Close the figure to release resources
-
-    # Save Results
-    # Rewards
-    fileName = f'./data/test_rewards_{jammer_type}_csc_{channel_switching_cost}.json'
-    with open(fileName, 'w') as f:
-        json.dump(rewards, f)
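Note on the signature change: test() now requires the live agent object returned by train(); the removed branch that rebuilt a DoubleDeepQNetwork and loaded saved weights is gone, so a previously saved agent can no longer be tested on its own. Below is a sketch of an optional fallback, assembled only from calls the removed code already used; the helper name get_test_agent and the agent-is-None convention are hypothetical.

    from DDQN import DoubleDeepQNetwork


    def get_test_agent(agent, s_size, a_size, jammer_type, channel_switching_cost):
        # Hypothetical helper: reuse a live agent, otherwise reload one from ./data
        if agent is not None:
            return agent  # in-memory agent handed over by train()
        # Rebuild and load exactly as the removed tester.py lines did
        # (positional args: lr, discount_rate, epsilon, epsilon_min, epsilon_decay)
        ddqn = DoubleDeepQNetwork(s_size, a_size, 0.001, 0.95, 1.0, 0.01, 0.999)
        agentName = f'./data/DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
        ddqn.model = ddqn.load_saved_model(agentName)
        return ddqn

Inside test(), DDQN_agent = get_test_agent(agent, s_size, a_size, jammer_type, channel_switching_cost) would then cover both paths.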
trainer.py
CHANGED
@@ -21,7 +21,7 @@ def train(jammer_type, channel_switching_cost):
     s_size = ob_space.shape[0]
     a_size = ac_space.n
     max_env_steps = 100
-    TRAIN_Episodes =
+    TRAIN_Episodes = 25
     env._max_episode_steps = max_env_steps
 
     epsilon = 1.0
@@ -85,16 +85,17 @@ def train(jammer_type, channel_switching_cost):
     st.pyplot(fig)
 
     # Save the figure
-    plot_name = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.png'
-    plt.savefig(plot_name, bbox_inches='tight')
+    # plot_name = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.png'
+    # plt.savefig(plot_name, bbox_inches='tight')
     plt.close(fig)  # Close the figure to release resources
 
     # Save Results
     # Rewards
-    fileName = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.json'
-    with open(fileName, 'w') as f:
-        json.dump(rewards, f)
-
-    # Save the agent as a SavedAgent.
-    agentName = f'./data/DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
-    DDQN_agent.save_model(agentName)
+    # fileName = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.json'
+    # with open(fileName, 'w') as f:
+    #     json.dump(rewards, f)
+    #
+    # # Save the agent as a SavedAgent.
+    # agentName = f'./data/DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
+    # DDQN_agent.save_model(agentName)
+
+    return DDQN_agent
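Note on persistence: with the save block commented out, train() now returns the agent in memory only, so nothing survives a restart of the Space and there are no saved weights to fall back on. If saving is being skipped because ./data may be absent or read-only, a guarded save is one option; here is a sketch reusing only the calls from the commented-out block (the helper name save_results and the writability guard are assumptions, not part of the commit).

    import json
    import os


    def save_results(DDQN_agent, rewards, jammer_type, channel_switching_cost, data_dir='./data'):
        # Hypothetical helper: persist rewards and weights only when ./data exists
        if not os.path.isdir(data_dir):
            return  # skip persistence on hosts without a ./data directory
        fileName = f'{data_dir}/train_rewards_{jammer_type}_csc_{channel_switching_cost}.json'
        with open(fileName, 'w') as f:
            json.dump(rewards, f)
        # Save the agent as a SavedAgent (save_model is defined in the project's DDQN class)
        agentName = f'{data_dir}/DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
        DDQN_agent.save_model(agentName)

train() could call this just before return DDQN_agent without changing its return value.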