This is the function that reads a checkpoint file and returns the network parameters:
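```python
import os
import flax


def load_network_params(agent_name: str, env_name: str,
                        network_root_folder: str = 'jax-models') -> flax.core.FrozenDict:
    # Checkpoint lives at <network_root_folder>/<agent_name>/<env_name>/2/ckpt.199
    filePathString = os.path.join(network_root_folder, agent_name, env_name, "2")
    fileNameStr = "ckpt.199"
    fileString = os.path.join(filePathString, fileNameStr)
    with open(fileString, 'rb') as file:
        # File processing irrelevant to the error (omitted); it builds network_params.
        network_params = ...  # placeholder for the omitted processing
    return network_params
```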
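A small refactor can make this kind of failure easier to diagnose: build the path with `pathlib` and check it before calling `open()`, so the error message shows the absolute path the interpreter is really looking at. This is a minimal sketch; `checkpoint_path` and `require_file` are hypothetical helpers, and the `2/ckpt.199` layout is simply copied from the function above.

```python
from pathlib import Path


def checkpoint_path(network_root_folder: str, agent_name: str, env_name: str,
                    run: str = "2", ckpt_name: str = "ckpt.199") -> Path:
    """Build <root>/<agent>/<env>/<run>/<ckpt_name> (layout taken from the question)."""
    return Path(network_root_folder) / agent_name / env_name / run / ckpt_name


def require_file(path: Path) -> Path:
    """Raise a descriptive error before open() if the path is not a regular file."""
    if not path.is_file():
        # Print the absolute form so it is obvious which filesystem the
        # interpreter is resolving against (e.g. Windows vs. WSL).
        raise FileNotFoundError(f"checkpoint not found: {path.resolve()}")
    return path
```

Usage would be `with open(require_file(checkpoint_path(root, agent, env)), "rb") as file:`.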
This is how it was called:
```python
from dataset import *

generate_dataset(r"dqn", r"Breakout", 10_000, network_root_folder=r"C:/Users/jk5g19/Documents/Year3IP/scripts/jax")
```
Edit: `generate_dataset` calls `load_network_params`. I cut down the amount of code shared to avoid confusing people; this is a minimal reproducible example:
```python
from typing import Tuple

import gym
import numpy as onp


def generate_dataset(agent_name: str, env_name: str, dataset_size: int, num_envs: int = 20,
                     epsilon: float = 0.1, network_root_folder: str = 'jax-models'
                     ) -> Tuple[onp.ndarray, onp.ndarray, onp.ndarray, onp.ndarray, int]:
    num_actions = gym.make(f'{env_name}NoFrameskip-v0').action_space.n
    # Preallocate the dataset arrays
    images_obs_dataset = onp.zeros((dataset_size, 84, 84, 4))
    ram_obs_dataset = onp.zeros((dataset_size, 128))
    q_values_dataset = onp.zeros((dataset_size, num_actions))
    action_dataset = onp.zeros(dataset_size)
    # get_network_def is defined elsewhere in dataset.py (not shown)
    network_def, network_args = get_network_def(agent_name, num_actions)
    # load_network_params is where open() fails
    network_params = load_network_params(agent_name, env_name, network_root_folder=network_root_folder)
    # ... data collection omitted; it fills the arrays and sets episodes_run ...
    return images_obs_dataset, ram_obs_dataset, q_values_dataset, action_dataset, episodes_run
```
When I print `fileString` and paste it into File Explorer, the file opens fine. I used raw strings for the path and tried both double backslashes and forward slashes.
I also put a `test.txt` file in the same folder; opening it failed in the same way, which rules out the file type as the cause of the problem with `open()`.
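Since raw strings and doubled backslashes produce identical strings (`r"C:\Users" == "C:\\Users"`), the slash style was never the real variable; what matters is which operating system the interpreter is running on. A quick diagnostic could look like the sketch below. `describe_path` is a hypothetical helper, and the WSL check is a heuristic (an assumption): WSL kernels usually report "microsoft" in their release string.

```python
import os
import platform
import sys


def describe_path(path: str) -> dict:
    """Collect facts that help explain why open() cannot see a path."""
    release = platform.uname().release.lower()
    return {
        "path": path,
        "exists": os.path.exists(path),
        "os.name": os.name,            # 'nt' on Windows, 'posix' on Linux/WSL
        "sys.platform": sys.platform,  # 'win32' on Windows, 'linux' on Linux/WSL
        # Heuristic (assumption): WSL kernels include 'microsoft' in the release.
        "looks_like_wsl": sys.platform.startswith("linux") and "microsoft" in release,
    }


print(describe_path(r"C:/Users/jk5g19/Documents/Year3IP/scripts/jax"))
```

If `os.name` is `'posix'` while the path starts with a drive letter, the path is being interpreted by a Linux environment that has no `C:` drive.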
It turns out that in my `dataset_setup.py` I needed to do this:
```python
from dataset import *

generate_dataset(r"dqn", r"Breakout", 10_000, network_root_folder=r"/mnt/c/Users/jk5g19/Documents/Year3IP/scripts/jax")
```
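This was because I was running the script under WSL, where the Windows `C:` drive is mounted at `/mnt/c`, so a Windows-style path such as `C:/Users/...` does not exist from Python's point of view. Many thanks to @gionni for helping me reach this conclusion.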
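To avoid hard-coding both forms, the Windows path could be translated before it is passed in. This is a minimal sketch assuming WSL's default automount root of `/mnt`; `windows_to_wsl` is a hypothetical helper (inside WSL, the `wslpath` command does a similar conversion).

```python
import re

# Drive letter, a colon, an optional separator, then the rest of the path.
_DRIVE = re.compile(r"^([A-Za-z]):[\\/]?(.*)$")


def windows_to_wsl(path: str, mount_root: str = "/mnt") -> str:
    """Translate 'C:/Users/...' or 'C:\\Users\\...' to '/mnt/c/Users/...'.

    Paths without a drive letter are returned unchanged.
    """
    match = _DRIVE.match(path)
    if match is None:
        return path
    drive, rest = match.groups()
    rest = rest.replace("\\", "/")  # normalise Windows separators
    base = f"{mount_root}/{drive.lower()}"
    return f"{base}/{rest}" if rest else base


print(windows_to_wsl(r"C:/Users/jk5g19/Documents/Year3IP/scripts/jax"))
```

The result could then be passed as `network_root_folder=windows_to_wsl(...)`, so the same call works whether the path was copied from File Explorer or typed in WSL form.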