Python|Python 还能实现哪些 AI 游戏？附上代码一起来一把（三）
然后定义整体的 DQN 类，分别定义初始化函数、训练函数和网络创建函数，其中创建网络的部分主要就是对各神经网络层的调用。然后在训练函数里记录当前动作和游戏数据，并把从中随机采样的数据送入优化器，从而达到训练模型的效果。
其中代码如下：
def __init__(self, options):
    """Store the DQN hyper-parameters taken from an options mapping.

    Args:
        options: mapping that must provide every key read below
            ('num_action', 'lr', 'modelDir', 'init_prob', 'end_prob',
            'OBSERVE', 'EXPLORE', 'action_interval', 'REPLAY_MEMORY',
            'gamma', 'batch_size', 'save_interval', 'logfile', 'is_train').
            A missing key raises KeyError. The mapping itself is kept as
            ``self.options``.
    """
    self.options = options
    self.num_action = options['num_action']
    self.lr = options['lr']
    self.modelDir = options['modelDir']
    self.init_prob = options['init_prob']
    self.end_prob = options['end_prob']
    self.OBSERVE = options['OBSERVE']
    self.EXPLORE = options['EXPLORE']
    self.action_interval = options['action_interval']
    self.REPLAY_MEMORY = options['REPLAY_MEMORY']
    self.gamma = options['gamma']
    self.batch_size = options['batch_size']
    self.save_interval = options['save_interval']
    self.logfile = options['logfile']
    self.is_train = options['is_train']


def train(self, session):
    """Train the Q-network by playing Pong forever with experience replay.

    Builds the loss and an Adam train op on top of ``create_network()``,
    restores the latest checkpoint in ``self.modelDir`` if one exists, and
    then loops forever:

    * pick an action with ``get_action_idx`` from the current Q-values and
      the exploration probability ``prob``; ``prob`` is updated every
      iteration by ``down_prob``;
    * step the game ``action_interval`` times, storing every
      (scene, action, reward, next_scene, terminal) transition in a replay
      memory capped at ``REPLAY_MEMORY`` entries;
    * once more than ``OBSERVE`` frames have passed (and the memory holds at
      least ``batch_size`` transitions), run one optimisation step on a
      random minibatch;
    * save a checkpoint every ``save_interval`` frames, and append one log
      line per iteration to ``self.logfile`` (also printed).

    Args:
        session: TensorFlow 1.x session used for every graph evaluation.

    The loop never returns normally. The log file is closed if the loop is
    ended by an exception (for example KeyboardInterrupt).
    """
    x, q_values_ph = self.create_network()
    action_now_ph = tf.placeholder('float', [None, self.num_action])
    target_q_values_ph = tf.placeholder('float', [None])
    # Compute the loss.
    loss = self.compute_loss(q_values_ph, action_now_ph, target_q_values_ph)
    # Optimisation target. tf.layers.batch_normalization registers its
    # moving-mean/variance updates in UPDATE_OPS; they only run if the train
    # op depends on them.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_step = tf.train.AdamOptimizer(self.lr).minimize(loss)
    # The game.
    game_state = PongGame()
    # Replay memory; maxlen discards the oldest transition once full.
    replay_memory = deque(maxlen=self.REPLAY_MEMORY)
    # Current action (one-hot), starting with action 0.
    action_now = np.zeros(self.num_action)
    action_now[0] = 1
    # Initial game state: resize to 80x80, grayscale, binarise, and stack
    # four copies of the frame along the last axis.
    x_now, reward, terminal = game_state.update_frame(action_now)
    x_now = cv2.cvtColor(cv2.resize(x_now, (80, 80)), cv2.COLOR_BGR2GRAY)
    _, x_now = cv2.threshold(x_now, 127, 255, cv2.THRESH_BINARY)
    scene_now = np.stack((x_now, ) * 4, axis=2)
    # Load and save checkpoints.
    saver = tf.train.Saver()
    session.run(tf.global_variables_initializer())
    checkpoint = tf.train.get_checkpoint_state(self.modelDir)
    if checkpoint and checkpoint.model_checkpoint_path:
        saver.restore(session, checkpoint.model_checkpoint_path)
        print('[INFO]: Load %s successfully...' % checkpoint.model_checkpoint_path)
    else:
        print('[INFO]: No weights found, start to train a new model...')
    # Boolean placeholder set by create_network(). With training=True, batch
    # norm over a single frame normalises every feature to 0, so the Q-values
    # used for action selection would not depend on the input. Feeding False
    # makes that evaluation use the moving statistics instead.
    bn_training = getattr(self, 'bn_training', None)
    prob = self.init_prob
    num_frame = 0
    with open(self.logfile, 'a') as log_file:
        while True:
            act_feed = {x: [scene_now]}
            if bn_training is not None:
                act_feed[bn_training] = False
            q_values = session.run(q_values_ph, feed_dict=act_feed)
            action_idx = get_action_idx(q_values=q_values,
                                        prob=prob,
                                        num_frame=num_frame,
                                        OBSERVE=self.OBSERVE,
                                        num_action=self.num_action)
            action_now = np.zeros(self.num_action)
            action_now[action_idx] = 1
            prob = down_prob(prob=prob,
                             num_frame=num_frame,
                             OBSERVE=self.OBSERVE,
                             EXPLORE=self.EXPLORE,
                             init_prob=self.init_prob,
                             end_prob=self.end_prob)
            for _ in range(self.action_interval):
                scene_next, reward, terminal = self.next_frame(action_now=action_now,
                                                               scene_now=scene_now,
                                                               gameState=game_state)
                # Store (s, a, r, s', terminal) BEFORE advancing; advancing
                # first would store s' in place of s.
                replay_memory.append((scene_now, action_now, reward, scene_next, terminal))
                scene_now = scene_next
            loss_now = None
            # random.sample raises ValueError if asked for more items than exist.
            if num_frame > self.OBSERVE and len(replay_memory) >= self.batch_size:
                minibatch = random.sample(replay_memory, self.batch_size)
                scene_now_batch = [mb[0] for mb in minibatch]
                action_batch = [mb[1] for mb in minibatch]
                reward_batch = [mb[2] for mb in minibatch]
                scene_next_batch = [mb[3] for mb in minibatch]
                q_values_batch = session.run(q_values_ph, feed_dict={x: scene_next_batch})
                target_q_values = self.compute_target_q_values(reward_batch, q_values_batch, minibatch)
                train_feed = {
                    target_q_values_ph: target_q_values,
                    action_now_ph: action_batch,
                    x: scene_now_batch,
                }
                session.run(train_step, feed_dict=train_feed)
                loss_now = session.run(loss, feed_dict=train_feed)
            num_frame += 1
            if num_frame % self.save_interval == 0:
                name = 'DQN_Pong'
                saver.save(session, os.path.join(self.modelDir, name), global_step=num_frame)
            log_content = '<Frame>: %s, <Prob>: %s, <Action>: %s, <Reward>: %s, <Q_max>: %s, <Loss>: %s' % (
                str(num_frame), str(prob), str(action_idx), str(reward),
                str(np.max(q_values)), str(loss_now))
            log_file.write(log_content + '\n')
            print(log_content)


def create_network(self):
    """Create the Q-network graph.

    Three convolution layers (the first followed by max-pooling), then a
    512-unit fully connected layer with batch normalisation, then a linear
    output layer.

    Returns:
        A tuple ``(x, fc2)``. ``x`` is the input placeholder of shape
        [None, 80, 80, 4]. ``fc2`` is the output tensor of shape
        [None, num_action]; ``train`` uses it as the per-action Q-values.

    Side effect:
        Sets ``self.bn_training``, a boolean scalar placeholder whose default
        is ``self.is_train``. It selects whether batch normalisation uses
        batch statistics (True) or its moving averages (False).
    """
    W_conv1 = self.init_weight_variable([8, 8, 4, 32])
    b_conv1 = self.init_bias_variable([32])
    W_conv2 = self.init_weight_variable([4, 4, 32, 64])
    b_conv2 = self.init_bias_variable([64])
    W_conv3 = self.init_weight_variable([3, 3, 64, 64])
    b_conv3 = self.init_bias_variable([64])
    # 5 * 5 * 64 = 1600. NOTE(review): this assumes conv2D uses SAME padding
    # and maxpool halves the spatial size (80 -> 20 -> 10 -> 5 -> 5); TODO
    # confirm against those helpers.
    W_fc1 = self.init_weight_variable([1600, 512])
    b_fc1 = self.init_bias_variable([512])
    W_fc2 = self.init_weight_variable([512, self.num_action])
    b_fc2 = self.init_bias_variable([self.num_action])
    # Input placeholder.
    x = tf.placeholder('float', [None, 80, 80, 4])
    # Defaults to is_train, so callers that feed only x behave as before.
    self.bn_training = tf.placeholder_with_default(bool(self.is_train), shape=[], name='bn_training')
    conv1 = tf.nn.relu(self.conv2D(x, W_conv1, 4) + b_conv1)
    pool1 = self.maxpool(conv1)
    conv2 = tf.nn.relu(self.conv2D(pool1, W_conv2, 2) + b_conv2)
    conv3 = tf.nn.relu(self.conv2D(conv2, W_conv3, 1) + b_conv3)
    flatten = tf.reshape(conv3, [-1, 1600])
    fc1 = tf.nn.relu(tf.layers.batch_normalization(tf.matmul(flatten, W_fc1) + b_fc1,
                                                   training=self.bn_training,
                                                   momentum=0.9))
    fc2 = tf.matmul(fc1, W_fc2) + b_fc2
    return x, fc2
推荐阅读
- 模型|REVIT技巧!如何创建能量模型,实现能量优化
- 懂懂笔记|主播也拼性价比 除了免坑费降抽成还能拼什么?
- 人工智能|人工智能上车就是聊天唱歌?TA还能给你有温度有情感的陪伴
- 数据|翼方健数解码隐私安全计算 实现数据“可用不可见”
- IOS系统|苹果免签封装如何实现?苹果免签封装会不会掉签?
- 小米手机|雷军再晒小米手机1代,压箱底整整9年,如今闲鱼还能卖888
- 中年|神奇!这款智能垃圾集成箱能自动开合还能紫外线消毒
- 行业互联网|最前线丨泰格医药通过港交所上市聆讯,或将实现“A+H”两地上市
- |为什么我店铺流量狂掉?淘宝竞争这么激烈还能不能做?
- 科学探索|为对抗美国,俄想和中国建月球基地,却被一张图扎心了,还能拿出什么?
