main.py (3774B)
import random
import numpy as np
import matplotlib.pyplot as plt
import pickle
import time
from matplotlib import style
from Space_Invaders import Space_Invaders
from Actions import Actions
from MAMEToolkit.emulator import Emulator
from MAMEToolkit.emulator import Address
from MAMEToolkit.emulator import Action
from MAMEToolkit.emulator import list_actions

roms_path = "/home/john/media/downloads/Transmission/MAME 0.220 ROMs (split)/"
game_id = "invaders"

# print(list_actions(roms_path, game_id))

# env = Space_Invaders("env1", roms_path)
# env.start()


def add_action_to_observation(observation, action):
    # Prepend the chosen action to the observation vector.
    return np.append([action], observation)


def initial_training():
    env = Space_Invaders("env1", roms_path)
    env.start()

    episode_rewards = []   # per-step reward log (one entry per frame, not per episode)
    SHOOT_PENALTY = 1      # currently unused
    DEATH_PENALTY = 50
    KILL_REWARD = 500
    MISS_PENALTY = 10
    epsilon = 0.6          # exploration rate, decayed every step
    EPS_DECAY = 0.9998
    SHOW_EVERY = 1000      # how often to play through env visually.

    # 5 discrete states x 4 actions, randomly initialised.
    q_table = np.random.rand(5, 4)
    # start_q_table = "qtable-1608050271.pickle"

    # with open(start_q_table, "rb") as f:
    #     q_table = pickle.load(f)

    avg_reward = []

    LEARNING_RATE = 0.1
    DISCOUNT = 0.95
    games = 15
    score = 0
    reward = 0
    iterations = 0
    action = random.randint(0, 3)
    score_list = []

    while games > 0:
        iterations = iterations + 1
        frames, rewards, game_over = env.step(action)
        episode_reward = 0
        old_score = score
        score = rewards["score"]
        aliens = rewards["aliens"]
        alive = rewards["alive"]
        shot_status = rewards["shot_status"]

        # Decode the game's memory values into one of five discrete states.
        # The reward mapping below implies: 0 = shot hit an alien,
        # 1 = player died, 3 = shot still in flight, 4 = shot missed.
        if alive == 0:
            state_index = 1
        elif shot_status == 12303 or shot_status == 12298:
            state_index = 0
        elif shot_status == 12294:
            state_index = 3
        else:
            state_index = 4

        obs = (state_index, action)

        if game_over:
            games = games - 1
            score_list.append(score)
            env.new_game()
        else:
            # Epsilon-greedy selection over the current state's row
            # (argmax over the whole row, not a single Q-value).
            if np.random.random() > epsilon:
                action = np.argmax(q_table[state_index])
            else:
                action = np.random.randint(0, 4)  # 4 actions: 0..3 inclusive

            if state_index == 1:
                reward = -DEATH_PENALTY
            elif state_index == 4:
                reward = -MISS_PENALTY  # penalty, so applied negatively
            elif state_index == 0:
                reward = KILL_REWARD
            elif state_index == 3:
                reward = 0
            else:
                reward = -1

            # Q-learning update. The true next state is not known until the
            # next env.step(), so the current state's row stands in for the
            # next state's when estimating the maximum future value.
            max_future_q = np.max(q_table[state_index])
            current_q = q_table[obs]

            if reward == KILL_REWARD:
                new_q = KILL_REWARD
            else:
                new_q = (1 - LEARNING_RATE) * current_q + LEARNING_RATE * (reward + DISCOUNT * max_future_q)
            q_table[obs] = new_q  # write the update back into the table

            episode_reward += reward
            episode_rewards.append(episode_reward)
            epsilon *= EPS_DECAY
            avg_reward.append(sum(episode_rewards) / len(episode_rewards))

    # SHOW_EVERY-step moving average of the per-step rewards
    # (computed for inspection; the running mean below is what gets plotted).
    moving_avg = np.convolve(episode_rewards, np.ones((SHOW_EVERY,)) / SHOW_EVERY, mode='valid')

    env.close()

    print(q_table)

    with open(f"qtable-{int(time.time())}.pickle", "wb") as f:
        pickle.dump(q_table, f)

    plt.plot(range(len(avg_reward)), avg_reward)
    plt.ylabel("mean reward per step (running average)")
    plt.xlabel("step #")
    plt.show()


if __name__ == "__main__":
    # training_data = initial_training()
    initial_training()
    # np.save('training', training_data)
    # train_net()
    # test_net()
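For reference, here is a minimal sketch of how a pickled Q-table saved by initial_training() could be replayed with a pure greedy policy. It assumes only the Space_Invaders wrapper calls already used above (start, step, new_game, close) and the same memory-value state encoding; the evaluate helper is illustrative and the pickle filename is a placeholder, neither is part of the repo.

# evaluate.py -- illustrative sketch, not part of the original repo.
import pickle
import numpy as np
from Space_Invaders import Space_Invaders

roms_path = "/home/john/media/downloads/Transmission/MAME 0.220 ROMs (split)/"
q_table_path = "qtable-1608050271.pickle"  # placeholder: any table saved by initial_training()


def decode_state(rewards):
    # Same state encoding as the training loop in main.py.
    if rewards["alive"] == 0:
        return 1                                  # player was hit
    if rewards["shot_status"] in (12303, 12298):
        return 0                                  # shot hit an alien
    if rewards["shot_status"] == 12294:
        return 3                                  # shot still in flight
    return 4                                      # shot missed


def evaluate(games=5):
    with open(q_table_path, "rb") as f:
        q_table = pickle.load(f)

    env = Space_Invaders("eval1", roms_path)
    env.start()
    action = 0
    scores = []
    while games > 0:
        frames, rewards, game_over = env.step(action)
        if game_over:
            games -= 1
            scores.append(rewards["score"])
            env.new_game()
        else:
            # Greedy: always take the best action for the decoded state.
            action = int(np.argmax(q_table[decode_state(rewards)]))
    env.close()
    print("scores:", scores)


if __name__ == "__main__":
    evaluate()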