前言
本来想做强化学习小车,但是技术栈似乎飞了,所以还是一步一步来嘛。
Part 1 Code
import copy
import random
import json
import matplotlib.pyplot as plt
class OoxxMachine:
def __init__(self):
self.race = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
# 用于表示棋盘 0代表没下过 1 A玩家 2 B玩家
self.flag = "in_race"
# all situation is "in_race" or "a_win" or "b_win" or "all_lose"
# ----------------------------------------------------------------
self.learn_rate = 0.1
self.rand_poss = 0.05
self.net_values = {}
self.default_value = 0.5
def update_win(self):
"""
[[2,1,1][1,1,0][1,2,0]]
be like:
0 1 2
-------------
0 | x | o | o |
1 | x | x | |
2 | x | o | |
-------------
"""
if self.race[0][0] == self.race[0][1] == self.race[0][2]:
if self.race[0][0] == 1:
self.flag = "a_win"
elif self.race[0][0] == 2:
self.flag = "b_win"
else:
pass
if self.race[1][0] == self.race[1][1] == self.race[1][2]:
if self.race[1][0] == 1:
self.flag = "a_win"
elif self.race[1][0] == 2:
self.flag = "b_win"
else:
pass
if self.race[2][0] == self.race[2][1] == self.race[2][2]:
if self.race[2][0] == 1:
self.flag = "a_win"
elif self.race[2][0] == 2:
self.flag = "b_win"
else:
pass
if self.race[0][0] == self.race[1][0] == self.race[2][0]:
if self.race[0][0] == 1:
self.flag = "a_win"
elif self.race[0][0] == 2:
self.flag = "b_win"
else:
pass
if self.race[0][1] == self.race[1][1] == self.race[2][1]:
if self.race[0][1] == 1:
self.flag = "a_win"
elif self.race[0][1] == 2:
self.flag = "b_win"
else:
pass
if self.race[0][2] == self.race[1][2] == self.race[2][2]:
if self.race[0][2] == 1:
self.flag = "a_win"
elif self.race[0][2] == 2:
self.flag = "b_win"
else:
pass
if self.race[0][0] == self.race[1][1] == self.race[2][2]:
if self.race[2][2] == 1:
self.flag = "a_win"
elif self.race[2][2] == 2:
self.flag = "b_win"
else:
pass
if self.race[0][2] == self.race[1][1] == self.race[2][0]:
if self.race[0][2] == 1:
self.flag = "a_win"
elif self.race[0][2] == 2:
self.flag = "b_win"
else:
pass
all_chess = 0
for i in range(0, 3):
for j in range(0, 3):
if self.race[i][j] != 0:
all_chess += 1
# print(all_chess)
if all_chess == 8 and self.flag == "in_race":
self.flag = "all_lose"
return False
def reset(self):
self.race = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
self.flag = "in_race"
def do_once(self, racer: "str == a or b", location: list) -> str:
could_do = True
for i in range(0, 3):
if 0 in self.race[i]:
could_do = True
else:
pass
if not could_do:
self.flag = 'all_lose'
return "fin"
if racer == "a":
if self.race[location[0]][location[1]] == 0:
self.race[location[0]][location[1]] = 1
else:
raise ValueError("this location has been used")
if racer == "b":
if self.race[location[0]][location[1]] == 0:
self.race[location[0]][location[1]] = 2
else:
raise ValueError("this location has been used")
return "fin"
# 我对强化学习的理解还不够透彻
def refresh_net(self, now_race: list, next_race: list) -> bool:
# 传参: 赛场情况 需要更新的价值(在此赛场情况之前的价值) (本赛场)是否获胜
hash_value: int = hash(str(now_race))
hash_next: int = hash(str(next_race))
"""
# 如果给下死了就给Value 置于 0
if self.flag == 'b_win' or self.flag == 'all_lose':
self.net_values[hash_value] = 0
return False
# 更新下一次预期之获胜情况
copy_race = self.race
self.race = next_race
self.update_win()
if self.flag == 'a_win':
next_value = 1
self.net_values[hash_value] = 1
elif self.flag == 'b_win' or "all_lose":
next_value = 0
self.net_values[hash_value] = 0
self.race = copy_race
self.update_win()
"""
next_value = self.net_values[hash_next]
if hash_value not in self.net_values:
self.net_values[hash_value] = self.default_value
value = self.default_value
else:
value = self.net_values[hash_value]
value = value + (next_value - value) * self.learn_rate
self.net_values[hash_value] = value
return True
def save_net(self, filename='net.json'):
with open(filename, 'w') as file:
json.dump(self.net_values, file)
print(f"Net values saved to {filename}.")
def read_net(self, filename='net.json'):
with open(filename, 'r') as file:
self.net_values = json.load(file)
print(f"Net values loaded from {filename}.")
def random_player(self, player: str):
possible_location = []
race_copy = self.race
for i in range(0, 3):
for j in range(0, 3):
if race_copy[i][j] == 0:
possible_location.append([i, j])
if not possible_location:
self.flag = "all_lose"
return False
location = random.choice(possible_location)
self.do_once(player, location)
def start_train(self, epoch: int = 1000) -> bool:
self.reset()
a_win_times = 1
b_win_times = 1
win_rate = []
for times in range(1, epoch):
win_rate.append(a_win_times / (a_win_times + b_win_times))
plt.plot(win_rate)
# print(self.race)
if self.flag == "a_win":
a_win_times += 1
elif self.flag == "b_win":
b_win_times += 1
# print(times)
# print(self.net_values)
self.reset()
if random.randint(0, 1):
while self.flag == "in_race":
self.update_win()
if random.random() >= self.rand_poss:
next_races = []
for i in range(0, 3):
for j in range(0, 3):
if self.race[i][j] == 0:
races_copy = copy.deepcopy(self.race)
races_copy[i][j] = 1
next_races.append(races_copy)
else:
pass
values = []
for next_race in next_races:
copy_race = copy.deepcopy(self.race)
self.race = copy.deepcopy(next_race)
self.update_win()
if self.flag == 'a_win':
self.net_values[hash(str(next_race))] = 1
elif self.flag == 'b_win' or "all_lose":
self.net_values[hash(str(next_race))] = 0
self.race = copy.deepcopy(copy_race)
self.update_win()
next_hash = hash(str(next_race))
if next_hash not in self.net_values:
self.net_values[next_hash] = self.default_value
values.append(self.default_value)
else:
values.append(self.net_values[next_hash])
max_value = max(values)
max_indices = [index for index, value in enumerate(values) if value == max_value]
random_max_index = random.choice(max_indices)
next_race = next_races[random_max_index]
# print(next_races)
self.refresh_net(self.race, next_race)
self.race = next_race
# print(self.race)
else:
# print("random")
if self.random_player("a"):
pass
else:
break
self.random_player("b")
else:
while self.flag == "in_race":
self.update_win()
self.random_player("b")
if random.random() >= self.rand_poss:
next_races = []
for i in range(0, 3):
for j in range(0, 3):
if self.race[i][j] == 0:
races_copy = copy.deepcopy(self.race)
races_copy[i][j] = 1
next_races.append(races_copy)
else:
pass
values = []
for next_race in next_races:
copy_race = copy.deepcopy(self.race)
self.race = copy.deepcopy(next_race)
self.update_win()
if self.flag == 'a_win':
self.net_values[hash(str(next_race))] = 1
elif self.flag == 'b_win' or "all_lose":
self.net_values[hash(str(next_race))] = 0
self.race = copy.deepcopy(copy_race)
self.update_win()
next_hash = hash(str(next_race))
if next_hash not in self.net_values:
self.net_values[next_hash] = self.default_value
values.append(self.default_value)
else:
values.append(self.net_values[next_hash])
max_value = max(values)
max_indices = [index for index, value in enumerate(values) if value == max_value]
random_max_index = random.choice(max_indices)
next_race = next_races[random_max_index]
# print(next_races)
self.refresh_net(self.race, next_race)
self.race = next_race
# print(self.race)
else:
# print("random")
if self.random_player("a"):
pass
else:
break
# do the race once at here
print(f"a wins {str(a_win_times)} b wins {str(b_win_times)}")
print(f"A的胜率是{str(a_win_times / (a_win_times + b_win_times))}")
plt.show()
return True
if __name__ == "__main__":
aa = OoxxMachine()
# aa.read_net()
aa.start_train(10000)
# aa.save_net()