| importScripts('https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.12.0/dist/tf.min.js') |
| importScripts('agent_sac.js') |
| importScripts('reply_buffer.js') |
|
|
| ;(async () => { |
| const DISABLED = false |
|
|
| const agent = new AgentSac({batchSize: 100, verbose: true}) |
| await agent.init() |
| await agent.checkpoint() |
| agent.actor.summary() |
| self.postMessage({weights: await Promise.all(agent.actor.getWeights().map(w => w.array()))}) |
|
|
| const rb = new ReplyBuffer(50000, ({ state: [telemetry, frameL, frameR], action, reward }) => { |
| frameL.dispose() |
| frameR.dispose() |
| telemetry.dispose() |
| action.dispose() |
| reward.dispose() |
| }) |
|
|
| |
| |
| |
| |
| |
| const job = async () => { |
| |
| if (DISABLED) return 99999 |
| if (rb.size < agent._batchSize*10) return 1000 |
| |
| const samples = rb.sample(agent._batchSize) |
| if (!samples.length) return 1000 |
| |
| const |
| framesL = [], |
| framesR = [], |
| telemetries = [], |
| actions = [], |
| rewards = [], |
| nextFramesL = [], |
| nextFramesR = [], |
| nextTelemetries = [] |
| |
| for (const { |
| state: [telemetry, frameL, frameR], |
| action, |
| reward, |
| nextState: [nextTelemetry, nextFrameL, nextFrameR] |
| } of samples) { |
| framesL.push(frameL) |
| framesR.push(frameR) |
| telemetries.push(telemetry) |
| actions.push(action) |
| rewards.push(reward) |
| nextFramesL.push(nextFrameL) |
| nextFramesR.push(nextFrameR) |
| nextTelemetries.push(nextTelemetry) |
| } |
| |
| tf.tidy(() => { |
| console.time('train') |
| agent.train({ |
| state: [tf.stack(telemetries), tf.stack(framesL), tf.stack(framesR)], |
| action: tf.stack(actions), |
| reward: tf.stack(rewards), |
| nextState: [tf.stack(nextTelemetries), tf.stack(nextFramesL), tf.stack(nextFramesR)] |
| }) |
| console.timeEnd('train') |
| }) |
|
|
| console.time('train postMessage') |
| self.postMessage({ |
| weights: await Promise.all(agent.actor.getWeights().map(w => w.array())) |
| }) |
| console.timeEnd('train postMessage') |
| |
| return 1 |
| } |
| |
| |
| |
| |
| const tick = async () => { |
| try { |
| setTimeout(tick, await job()) |
| } catch (e) { |
| console.error(e) |
| setTimeout(tick, 5000) |
| } |
| } |
| |
| setTimeout(tick, 1000) |
| |
| |
| |
| |
| |
| |
| |
| const decodeTransition = transition => { |
| let { id, state: [telemetry, frameL, frameR], action, reward, priority } = transition |
| |
| return tf.tidy(() => { |
| state = [ |
| tf.tensor1d(telemetry), |
| tf.tensor3d(frameL, agent._frameStackShape), |
| tf.tensor3d(frameR, agent._frameStackShape) |
| ] |
| action = tf.tensor1d(action) |
| reward = tf.tensor1d([reward]) |
| |
| return { id, state, action, reward, priority } |
| }) |
| } |
| |
| let i = 0 |
| self.addEventListener('message', async e => { |
| i++ |
|
|
| if (DISABLED) return |
| if (i%50 === 0) console.log('RBSIZE: ', rb.size) |
| |
| switch (e.data.action) { |
| case 'newTransition': |
| const transition = decodeTransition(e.data.transition) |
| rb.add(transition) |
|
|
| tf.tidy(()=> { |
| return |
| const { |
| state: [telemetry, frameL, frameR], |
| action, |
| } = transition; |
| const state = [tf.stack([telemetry]), tf.stack([frameL]), tf.stack([frameR])] |
| const q1TargValue = agent.q1Targ.predict([...state, tf.stack([action])], {batchSize: 1}) |
| const q2TargValue = agent.q2Targ.predict([...state, tf.stack([action])], {batchSize: 1}) |
| console.log('value', Math.min(q1TargValue.arraySync()[0][0], q2TargValue.arraySync()[0][0]).toFixed(5)) |
| }) |
|
|
|
|
| break |
| default: |
| console.warn('Unknown action') |
| break |
| } |
| |
| if (i % rb._limit === 0) |
| agent.checkpoint() |
| }) |
| })() |
|
|