Spaces:
Running
Running
importScripts('https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js') | |
importScripts('https://ai-creature.github.io/agent_sac.js') | |
importScripts('https://ai-creature.github.io/reply_buffer.js') | |
;(async () => {
  /** Kill switch: when true, the worker neither trains nor ingests transitions. */
  const DISABLED = false
  /** Debug switch: when true, log the target-critic value of every incoming transition. */
  const LOG_TARGET_VALUES = false

  const agent = new AgentSac({ batchSize: 100, verbose: true })
  await agent.init()
  // NOTE(review): presumably persists the freshly initialised networks, overwriting any
  // earlier checkpoint (original note: "overwrite") — confirm against AgentSac.checkpoint.
  await agent.checkpoint()
  agent.actor.summary()

  /**
   * Sends the actor's current weights to the main thread as plain nested arrays
   * (the result of `Tensor.array()` for each weight).
   *
   * @returns {Promise<void>}
   */
  const postActorWeights = async () => {
    const weights = await Promise.all(agent.actor.getWeights().map(w => w.array()))
    self.postMessage({ weights })
  }

  await postActorWeights() // synchronize the main thread's actor with the initial weights

  // Buffer of up to 50000 transitions. The callback disposes a transition's state, action
  // and reward tensors; presumably ReplyBuffer invokes it when it evicts a transition —
  // confirm against reply_buffer.js. nextState is not disposed here: decodeTransition never
  // creates one, so the buffer assembles it itself (TODO confirm it reuses the successor's
  // state tensors, which are freed when that successor is evicted).
  const rb = new ReplyBuffer(50000, ({ state: [telemetry, frameL, frameR], action, reward }) => {
    frameL.dispose()
    frameR.dispose()
    telemetry.dispose()
    action.dispose()
    reward.dispose()
  })

  /**
   * Stacks sampled transitions into the batched input `agent.train` expects.
   * Each state (and nextState) is `[telemetry, frameL, frameR]`; every part is stacked
   * along a new leading batch axis. Call only inside `tf.tidy` so the stacked copies are
   * disposed once training finishes.
   *
   * @param {Array<{ state, action, reward, nextState }>} samples - transitions from `rb.sample`
   * @returns {{ state: tf.Tensor[], action: tf.Tensor, reward: tf.Tensor, nextState: tf.Tensor[] }}
   */
  const stackBatch = samples => {
    const stackPart = (key, index) => tf.stack(samples.map(sample => sample[key][index]))
    return {
      state: [stackPart('state', 0), stackPart('state', 1), stackPart('state', 2)],
      action: tf.stack(samples.map(sample => sample.action)),
      reward: tf.stack(samples.map(sample => sample.reward)),
      nextState: [stackPart('nextState', 0), stackPart('nextState', 1), stackPart('nextState', 2)]
    }
  }

  /**
   * Worker: runs one training step on a sampled batch, then publishes the updated actor
   * weights to the main thread.
   *
   * @returns {Promise<number>} delay in ms to get ready for the next job
   */
  const job = async () => {
    if (DISABLED) return 99999
    // Wait until the buffer holds at least ten batches' worth of transitions.
    if (rb.size < agent._batchSize * 10) return 1000

    // Keep sampling and training synchronous (no await in between): the message handler
    // keeps calling rb.add, which may presumably evict — and dispose — sampled transitions.
    const samples = rb.sample(agent._batchSize) // cheap
    if (samples.length === 0) return 1000

    // try/finally so a throwing train step does not leave the 'train' timer running.
    console.time('train')
    try {
      // Block body on purpose: whatever agent.train returns is discarded rather than
      // being kept alive by tidy; the stacked batch tensors are released here.
      tf.tidy(() => {
        agent.train(stackBatch(samples))
      })
    } finally {
      console.timeEnd('train')
    }

    console.time('train postMessage')
    try {
      await postActorWeights()
    } finally {
      console.timeEnd('train postMessage')
    }
    return 1
  }

  /**
   * Runs `job` forever, rescheduling itself after the delay the job returns.
   * A failing job is logged and retried after 5 s instead of stopping the loop.
   */
  const tick = async () => {
    let delay
    try {
      delay = await job()
    } catch (e) {
      console.error(e)
      delay = 5000 // show must go on (҂◡_◡) ᕤ
    }
    setTimeout(tick, delay)
  }
  setTimeout(tick, 1000)

  /**
   * Decodes a transition posted by the main thread into tensors.
   *
   * @param {{ id, state: [telemetry, frameL, frameR], action, reward: number, priority }} transition
   *   telemetry/action are array-like values for `tf.tensor1d`; each frame is reshaped to
   *   `agent._frameStackShape`; reward is a single number.
   * @returns {{ id, state: tf.Tensor[], action: tf.Tensor1D, reward: tf.Tensor1D, priority }}
   *   `id` and `priority` are passed through untouched.
   */
  const decodeTransition = ({ id, state: [telemetry, frameL, frameR], action, reward, priority }) =>
    tf.tidy(() => ({
      id,
      state: [
        tf.tensor1d(telemetry),
        tf.tensor3d(frameL, agent._frameStackShape),
        tf.tensor3d(frameR, agent._frameStackShape)
      ],
      action: tf.tensor1d(action),
      reward: tf.tensor1d([reward]),
      priority
    }))

  /**
   * Debug helper: logs the smaller of the `agent.q1Targ` / `agent.q2Targ` predictions for a
   * single decoded transition. Only called when LOG_TARGET_VALUES is true.
   *
   * @param {{ state: tf.Tensor[], action: tf.Tensor }} transition - output of decodeTransition
   */
  const logTargetValue = ({ state: [telemetry, frameL, frameR], action }) => {
    tf.tidy(() => {
      // tf.stack([x]) adds a leading batch axis of size 1.
      const input = [tf.stack([telemetry]), tf.stack([frameL]), tf.stack([frameR]), tf.stack([action])]
      const q1TargValue = agent.q1Targ.predict(input, { batchSize: 1 })
      const q2TargValue = agent.q2Targ.predict(input, { batchSize: 1 })
      console.log('value', Math.min(q1TargValue.arraySync()[0][0], q2TargValue.arraySync()[0][0]).toFixed(5))
    })
  }

  let messageCount = 0
  self.addEventListener('message', e => {
    messageCount++
    if (DISABLED) return
    if (messageCount % 50 === 0) console.log('RBSIZE: ', rb.size)

    try {
      switch (e.data.action) {
        case 'newTransition': {
          const transition = decodeTransition(e.data.transition)
          rb.add(transition)
          if (LOG_TARGET_VALUES) logTargetValue(transition)
          break
        }
        default:
          console.warn('Unknown action')
          break
      }
    } catch (err) {
      // A malformed message must not take the worker down; log it and keep going.
      console.error(err)
    }

    // Checkpoint every rb._limit messages (presumably the buffer capacity — confirm).
    // It takes ~500 ms, so it is deliberately not awaited, but a failure is still reported
    // instead of becoming an unhandled rejection.
    if (messageCount % rb._limit === 0) {
      Promise.resolve()
        .then(() => agent.checkpoint())
        .catch(err => console.error(err))
    }
  })
})()