diff --git a/bindsnet/analysis/dotTrace_plotter.py b/bindsnet/analysis/dotTrace_plotter.py new file mode 100644 index 00000000..d9e552b1 --- /dev/null +++ b/bindsnet/analysis/dotTrace_plotter.py @@ -0,0 +1,151 @@ +import numpy as np +import glob +import sys + +import matplotlib.pyplot as plt + +# Define grid dimensions globally +ROWS = 28 +COLS = 28 + + +def plotGrids(gridData): + if gridData.shape[0] % ROWS != 0 or gridData.shape[1] != COLS: + raise ("Incompatible grid dimensionality: check data and assumed dimensions.") + + grids = gridData.shape[0] // ROWS + + print("Reshaping into", grids, "grids of shape (", ROWS, ",", COLS, ")") + gridData = gridData.reshape((grids, ROWS, COLS)) + + plotAnotherRange = True + + while plotAnotherRange: + start = -1 + end = 1 + print("Select the range of iterations to generate grid plots from.") + print("0 means plot all iterations.") + while (start < 0 or grids - 1 < start) or (end < 1 or grids < end): + start = int(input("Start: ")) + + # If start is set to zero, plot everything. + if start == 0: + continue + + end = int(input("End: ")) + + if start == 0: + print("\nPlotting whole shebang!") + else: + print("\nPlotting range from iteration", start, "to", end) + + # Plotting time! + plt.figure() + plt.ion() + plt.imshow(gridData[start], cmap="hot", interpolation="nearest") + plt.colorbar() + plt.pause(0.001) # Pause so that that GUI can do its thing. + for g in gridData[start + 1 : end]: + plt.imshow(g, cmap="hot", interpolation="nearest") + plt.pause(0.001) # Pause so that that GUI can do its thing. + + plotAnotherRange = str.lower(input("Plot another range? (y/n): ")) == "y" + + +def plotRewards(rewData, fname): + cumRewards = np.cumsum(rewData) + tsteps = np.array(range(len(cumRewards))) + + # Plotting time! 
+ plt.figure() + plt.plot(tsteps, cumRewards) + plt.xlabel("Timesteps") + plt.ylabel("Cumulative Reward") + plt.title("Cumulative Reward by Iteration") + plt.savefig(fname[0:-4] + ".png", dpi=200) + plt.pause(0.001) # Pause so that that GUI can do its thing. + + +def plotPerformance(perfData, fname): + + # Set bins to a tenth of the episodes, rounded up. + binIdx = np.array(range(len(perfData))) // 10 + bins = np.bincount(binIdx, perfData).astype("uint32") + + # Plotting time! + plt.figure() + plt.bar(np.unique(binIdx), bins, color="seagreen") + plt.xlabel("Episode Bins") + plt.ylabel("Number of Intercepts") + plt.title("Interception Performance Across Episodes") + plt.savefig(fname[0:-4] + ".png", dpi=200) + plt.pause(0.001) # Pause so that that GUI can do its thing. + + +def main(): + """ + File types: + + 0) grid - the 2D matrix observation + 1) reward - list of rewards per iteration + 2) performance - list of performance values + """ + fileType = 0 # default to grid + + # By default, we'll search the examples directory, but tweak as needed. + files = glob.glob("../../examples/*/out/*csv") + + if len(files) == 0: + print("Could not find any csv files. Exiting...") + sys.exit() + + plotAnotherFile = True + + while plotAnotherFile: + print("Select the file to generate grid plots from.") + for i, f in enumerate(files): + print(str(i), "-", f) + + # Select the intended file. + sel = -1 + while sel < 0 or len(files) < sel: + sel = int(input("\nFile selection: ")) + + fileToPlot = files[sel] + + # Check file type + if 0 < fileToPlot.find("grid"): + print("\nFound 'grid' in name: assuming a grid file type.") + fileType = 0 + elif 0 < fileToPlot.find("rew"): + print("\nFound 'rew' in name: assuming a reward file type.") + fileType = 1 + + elif 0 < fileToPlot.find("perf"): + print("\nFound 'perf' in name: assuming a performance file type.") + fileType = 2 + else: + print("\nUnknown file type. 
Which type are we plotting?") + print("\n0) grid\n1) reward\n2) performance") + fileType = -1 + while fileType < 0 or 2 < fileType: + fileType = int(input("\nFile type: ")) + + print("\nPlotting: ", fileToPlot) + data = np.genfromtxt(fileToPlot, delimiter=",") + + # Plot by file type + if fileType == 0: + plotGrids(data) + elif fileType == 1: + plotRewards(data, fileToPlot) + elif fileType == 2: + plotPerformance(data, fileToPlot) + else: + print("ERROR: Unknown file type") + + plotAnotherFile = str.lower(input("Plot another file? (y/n): ")) == "y" + + +if __name__ == "__main__": + main() diff --git a/bindsnet/encoding/encodings.py b/bindsnet/encoding/encodings.py index e4180266..8fc21bed 100644 --- a/bindsnet/encoding/encodings.py +++ b/bindsnet/encoding/encodings.py @@ -15,7 +15,7 @@ def single( """ Generates timing based single-spike encoding. Spike occurs earlier if the intensity of the input feature is higher. Features whose value is lower than - threshold is remain silent. + the threshold remain silent. :param datum: Tensor of shape ``[n_1, ..., n_k]``. :param time: Length of the input and output. diff --git a/bindsnet/environment/README.md b/bindsnet/environment/README.md new file mode 100644 index 00000000..5b1d32a7 --- /dev/null +++ b/bindsnet/environment/README.md @@ -0,0 +1,59 @@ +## Dot Simulator + +### Overview + +This simulator lets us generate dots and make them move in a configurable 2D space, providing a visual to a neural network for training in experiments. + +Specifically, this generates a grid for each timestep, where a specified number of points have values of 1 with fading tails ("decay"), designating the current positions and movements of their corresponding dots. All other points are set to 0. From timestep to timestep, the dots either remain where they are or move one space. + +The 2D observation of the current state is provided every step, as well as the reward, completion flag, and sucessful interception flag. 
It may be helpful to scale the grid values when encoding them as spike trains.
+
+The intended objective is to train a network to use its "network dot" to trace or intercept a moving "target" dot. But this simulator is designed to easily adapt to multiple kinds of experiments.
+
+
+### Dot Movement
+
+By default, there is a single "target" dot that moves in a random direction every timestep (or it can stay still, which can be disabled), and as it moves, it leaves a tunable "decay" in the form of a fading tail. The simulator supports four directions of movement (up/down/left/right) by default, as well as remaining still, but the diag parameter allows diagonal movement for more complexity. The rate of the target's randomized movement can also be modified (i.e., a random direction every timestep, or changing direction only every so often).
+
+The simulator supports multiple bounds-handling schemes. By default, dots will simply not move past the edges. Alternatively, the bound_hand parameter can be set to 'bounce', for a geometric reflection off the edges, or 'trans', which will have a mirrored result: a geometric translation to the opposite side of the grid.
+
+To add further complexity, additional targets can be added as desired via the dots parameter, and the herrs parameter can be set to generate multiple "red herrings" as distraction dots. The speed of the dots' movements can also be set; it is 1 by default.
+
+
+![DotTraceSample](../../docs/DotTraceSample.png)
+
+>The grid visuals provided by the render function will double the value of the network dot; this is a visual aid only, invisible to the network. + + +### Reward Functions + +This simulator supports multiple reward functions (aka. fitness functions): +- Euclidean (fit_func='euc'): the default option, this function computes the Euclidean (aka. Pythagorean) distance between the network dot and the target dot. +- Displacement (fit_func='disp'): this option computes the x,y displacement of the network dot with respect to the target dot, returning an x,y tuple. Currently, BindsNET only supports single reward values. To use this one, either be creative or update the network code... +- Range Rings (fit_func='rng'): this option uses the Euclidean distance and groups it into range rings. The radial distance of the range rings can be set by the ring_size parameter. +- Directional (fit_func='dir'): the directional option checks to see if the network's decision moved its dot closer, laterally, or further away from the target dot's prior position (ie. before applying movement this timestep) and returns a +1, 0, or -1 accordingly. + +Additionally, upon a successful intercept, the network will receive +10 if the bullseye parameter is active, and its dot will be teleported to another random location if the teleport parameter is active. + +>In the event multiple target dots are generated, the fitness functions only compute rewards with respect to the first target dot. + + +### Additional Features + +The environment can take a seed for random number generation in python, numpy, and Pytorch; otherwise, it will generate and save a new seed based on the current system time. + +As this simulator was developed in Anaconda Spyder on Windows, it can be run from Windows or Linux. Since environments handle plotting differently, and experiments can sometimes be terminated prematurely, this environment supports the recording of grid observations in text files and post-op plotting. 
Live rendering can also be disabled via the mute parameter, and a text-based alternative using pandas dataframe formatting can be enabled via the pandas parameter. + +Filenames and file paths can be specified for recording grid observations. By default, the filenames will be "grid" followed by "s#_$.csv" where # is the random seed used and $ is the current file number. addFileSuffix(suffix) adds the provided suffix (typically used for "train" or "test") to the filename, and changeFileSuffix(sFrom, sTo) will find sFrom in the filename and replace it with sTo. + +To ensure that files do not become too large to either be saved or be practically useful, cycleOutFiles(newInt) can be used to cycle the current save file, incrementing the file number suffix, or resetting it if newInt is set to a positive number. + +Post-op plotting is supported by dotTrace_plotter.py in the analysis directory. By default, this tool searches the examples directory for csvs in "out" directories, but that path can be easily changed. It supports plotting ranges of grid observations, reward plots, and performance plots. See below for an example of recording reward and performance data for plotting purposes. + + +### Example +See dot_tracing.py for an example in using the Dot Simulator for training an SNN in BindsNET. + +dot_tracing trains a basic RNN network on the dot simulator and demonstrates how to record reward and performance data (if desired) and plot spiking activity via monitors. + + diff --git a/bindsnet/environment/cue_reward.py b/bindsnet/environment/cue_reward.py new file mode 100644 index 00000000..857b91db --- /dev/null +++ b/bindsnet/environment/cue_reward.py @@ -0,0 +1,181 @@ +import numpy as np +import random +from time import time + +import torch + + +# Number of cues to be used in the experiment. 
+NUM_CUES = 4 + + +class CueRewardSimulator: + """ + This simulator provides basic cues and rewards according to the + network's choice, as described in the Backpropamine paper: + https://openreview.net/pdf?id=r1lrAiA5Ym + + :param epdur: int: duration (timesteps) of an episode; default = 200 + :param cuebits: int: max number of bits to hold a cue (max value = 2^n for n bits) + :param seed: real: random seed + :param zprob: real: probability of zero vectors in each trial. + """ + + def __init__(self, **kwargs) -> None: + self.ep_duration = kwargs.get("epdur", 200) # episode duration in timesteps + self.cuebits = kwargs.get("cuebits", 20) + self.seed = int(kwargs.get("seed", time())) + self.zeroprob = kwargs.get("zprob", 0.6) + + # zero array consists of the binary cue vector + four other fields. + self.zeroArray = np.zeros(self.cuebits + 4, dtype="int32") + + assert ( + 0.0 <= self.zeroprob and self.zeroprob <= 1.0 + ), "zprob must be valid probability" + + def make(self, name): + self.reset() # Simply reset according to grid definition. + + def step(self, action): + """ + Every trial, we randomly select two of the four cues to provide to the + network. Every timestep within that trial we either randomly display + only zeros, or we alternate between the two cues in the pair. + + At the end of a trial, we provide the response cue, for which the network + must respond 1 if the target was one of the provided cues or 0 if it was + not. The next timestep, we evaluate the response, giving a reward of 1 + for correct and -1 for incorrect. + + :param action: network's decision if the target cue was displayed. + :return obs: observation of vector with binary cue and the following fields: + - time since start of episode + - one-hot-encoded value for a response of 0 in previous timestep + - one-hot-encoded value for a response of 1 in previous timestep + - reward of previous timestep + :return reward: 1 for correct response; -1 for incorrect response. 
+        :return done: indicates termination of simulation
+        :return info: dictionary including values for debugging purposes.
+        """
+        self.tstep += 1  # increment episode timestep
+        self.trialTime -= 1  # decrement current trial timestep
+
+        # Populate base fields of observation.
+        # Copy the zero template: assigning it directly would alias it, and the
+        # in-place writes below would then permanently corrupt the shared
+        # "empty" vector (cue bits would leak into zero-vector steps).
+        self.obs = self.zeroArray.copy()  # default to empty array
+        self.obs[-4] = self.tstep  # time since start of episode.
+        self.obs[-3] = int(self.response == 0)  # response = 0 for previous timestep
+        self.obs[-2] = int(self.response == 1)  # response = 1 for previous timestep
+        self.obs[-1] = self.reward[0]  # reward of previous timestep
+
+        self.response = action  # Remember previous response
+        self.reward[0] = 0  # default current reward to 0
+
+        # If starting a new trial
+        if self.trialTime <= 0:
+            # Set new trial length, based on mean number of trials per episode = 15:
+            # randint(10, 20) (mean 15) is the trial count, so the trial length is
+            # ep_duration divided by it. (The operands were previously swapped,
+            # which floor-divided to 0 and restarted a trial every timestep.)
+            self.trialTime = self.ep_duration // random.randint(10, 20)
+
+            # Randomly select the pair of cues to be shown to the network.
+            self.pairmask = np.array(range(NUM_CUES))[
+                np.argsort(np.random.uniform(0, 1, 4)) < 2
+            ]
+
+            # Determine if target is one of these current cues displayed.
+            self.targ_disp = int(np.any(self.pairmask == self.target))
+
+            self.cue_pair_ind = 0  # Reset cue pair indicator
+
+        # Deterministic special cases for last two trial timesteps.
+        if self.trialTime <= 2:
+            # If it's the second to last trial timestep, cue response.
+            # Response cue is another binary cue vector but with a value of 1.
+            if self.trialTime == 2:
+                self.obs[0] = 1
+
+            # If it's the last trial timestep, provide empty input, check
+            # the answer to the response cue, and compute the reward.
+            else:
+                # reward = 1 for correct and -1 for incorrect.
+                self.reward[0] = int(action == self.targ_disp) * 2 - 1
+
+        # Else, roll the dice for a zero vector.
+        elif np.random.uniform(0, 1) > self.zeroprob:
+            # If we're not providing a zero vector, present one of
+            # the current cue pair and switch turns for next time. 
+ self.obs[:-4] = self.cues[self.pairmask][self.cue_pair_ind] + self.cue_pair_ind = (self.cue_pair_ind + 1) % 2 + + done = self.ep_duration <= self.tstep + info = { + "target": self.target, + "pairmask": self.pairmask, + "targ_disp": self.targ_disp, + } + + return self.obs, self.reward, done, info + + def reset(self): + """ + Reset reset RNG seed; generate new cue bit arrays, and arbitrarily + select one of the four cues as the "target" cue. + """ + # Re-seed random functions + random.seed(self.seed) + np.random.seed(self.seed) + + # Reset timesteps + self.tstep = 0 + self.trialTime = 1 + + # Initialize cue bit strings + CUE_MAX = pow(2, self.cuebits) + cues_ints = np.zeros(NUM_CUES, dtype="int32") + for i in range(NUM_CUES): + c = 0 + while np.any(cues_ints == c): + c = random.randint(2, CUE_MAX) # 1 reserved for response cue + cues_ints[i] = c + + self.cues = np.zeros((NUM_CUES, self.cuebits), dtype="int32") + for i in range(NUM_CUES): + binarray = np.array(list(np.binary_repr(cues_ints[i]))).astype("int32") + self.cues[i][: len(binarray)] = binarray + + # Randomly select the target cue for this episode. + self.target = random.randint(0, NUM_CUES) + + # provide empty default observation + self.obs = self.zeroArray + + # Reset reward + self.reward = torch.Tensor(1) + self.reward[0] = 0 # default reward to 0 + + # Instantiate response member, defaulting to 0. + self.response = 0 + + return self.obs + + def render(self): + """ + Display current input vector. 
+ """ + print(self.obs) + + +def driver(): + steps = 200 + cueSim = CueRewardSimulator() + cueSim.reset() + + observations = np.zeros((steps, 24), dtype="int32") + for t in range(steps): + observations[t], reward, done, info = cueSim.step(0) + + meanReward = observations[:, -1][observations[:, -1] != 0].mean() + print("Mean reward:", meanReward) + + +if __name__ == "__main__": + driver() diff --git a/bindsnet/environment/dot_simulator.py b/bindsnet/environment/dot_simulator.py new file mode 100644 index 00000000..e33cf160 --- /dev/null +++ b/bindsnet/environment/dot_simulator.py @@ -0,0 +1,526 @@ +import numpy as np +import os +import pandas as pd +import random +from time import time + +import torch +from gym import spaces +import matplotlib.pyplot as plt + + +# Mappings for changing direction if reflected. +# Cannot cross a row boundary moving right or left. +ROW_CROSSING = { + 1: 2, + 3: -2, + 5: 1, + 6: -1, + 7: 1, + 8: -1, +} + +# Cannot cross a column boundary moving up or down. +COL_CROSSING = { + 2: 2, + 4: -2, + 5: 3, + 6: 1, + 7: -1, + 8: -3, +} + + +class Dot: + def __init__(self, r: int = 0, c: int = 0, t: int = 1) -> None: + # Initialize current point and tail to initial point. + self.row = np.ones(t, dtype="int32") * r + self.col = np.ones(t, dtype="int32") * c + + def move(self, r: int, c: int): + """ + Cycle path history and set new current coordinates. + + :param r: row + :param c: column + """ + + # Cycle the path history. + for t in reversed(range(1, len(self.row))): + self.row[t] = self.row[t - 1] + self.col[t] = self.col[t - 1] + + # Set new current point + self.row[0] = r + self.col[0] = c + + +class DotSimulator: + """ + This simulator lets us generate dots and make them move. + It's especially useful in keeping entitled cats occupied, + but instead of feline neurons, we use this for fake ones. 
+ + Specifically, this generates a grid for each timestep, where a specified + number of points have values of 1 with fading tails ("decay"), designating + the current positions and movements of their corresponding dots. All other + points are set to 0. From timestep to timestep, the dots either remain + where they are or move one space. + + The 2D observation of the current state is provided every step, as well as + the reward, completion flag, and sucessful interception flag. It may be + helpful to amplify the grid values when encoding them as spike trains. + + :param t: int: number of timesteps/samples of grids with dot movements + :param height: int: height dimension of the grid (rows) + :param width: int: width dimension of the grid (columns) + :param decay: int: length of decaying tail behind a dot (its path history) + :param dots: int: number of target dots + :param herrs: number of distraction dots (red herrings) + :param pandas: Bool: print as pandas dataframe versus graphical plots + :param write: Bool: write grids to file to be plotted later. + :param mute: Bool: mute graphical rendering (can write to file or print pandas) + :param speed: int: set movement speed of dots. + :param randr: float: set the randomization rate of target movements. + :param allow_stay: bool: allow a dot to remain in place as a movement choice. + :param seed: int: optional seed for RNG in movement generation. + :param fname: string: optional filename for saving grids to file + :param fpath: string: optional file path for saving grids to file + :param diag: Bool: allow diagonal movement. + :param bound_hand: str: bounds handling when a dot reaches the world's end. + 'stay': dots will simply be prevented from crossing the edges. + 'bounce': dot positions and directions will be reflected. + 'trans': dot positions will be mirrored to the opposite edge. + :param fit_func: str: Fitness function. 
+ 'euc': Single Euclidean (Pythagorean) distance value + 'disp': Tuple of x,y displacement values + 'rng' : Range rings--the closer the ring, the lower the number + 'dir' : directional--+1 if moving in the right direction + -1 if moving in the wrong direction + 0 if neither. + :param ring_size: int: set range ring size for range ring fitness function. + :param bullseye: int: set reward for successful intercept; default = 10.0 + :param teleport: Bool: teleport network dot after intercept; default = true + """ + + def __init__(self, t: int, **kwargs) -> None: + + self.timesteps = t # total timesteps + self.ts = 0 # initialize current timestep to 0 + + """ Keyword arguments """ + self.h = kwargs.get("height", 28) # height dimension + self.w = kwargs.get("width", 28) # width dimension + self.decay = kwargs.get("decay", 1) # length of a dot's tail (path history) + self.ndots = kwargs.get("dots", 1) # Number of dots + self.herrs = kwargs.get("herrs", 0) # Red herrings (distractions) + self.pandas = kwargs.get("pandas", False) # print as pandas DF + self.write2F = kwargs.get("write", False) # write grids to file + self.mute = kwargs.get("mute", False) # mute displayed rendering + self.speed = kwargs.get("speed", 1) # dot movement speed + self.randr = kwargs.get("randr", 1.0) # rate of random movement + self.minch = int(not kwargs.get("allow_stay", True)) # allow stay choice. + + # Grab system time and trim off extra large parts of the number. + sysTime = time() + sysTime = int(1e10 * (sysTime - 1e6 * (sysTime // 1e6))) + + # Save off RNG seed. + self.seed = kwargs.get("seed", 0) + if self.seed == 0: + self.seed = sysTime + + # Create filename if one isn't provided. 
+ path = kwargs.get("fpath", "out") + self.filename = kwargs.get("fname", "grids_s" + str(self.seed) + ".csv") + self.fileCnt = 0 + self.filename = ( + path + + "/" + + self.filename[:-5] + + "_" + + str(self.fileCnt) + + self.filename[-4:] + ) + if not os.path.exists(path): + os.makedirs(path) + + # Expand movement options if diagonal is allowed. + if kwargs.get("diag", False): + self.choices = 9 + else: + self.choices = 5 + + # Enumerated bounds handling when a dot traverses the region's edge. + bh = kwargs.get("bound_hand", "stay") + if bh == "stay": + self.b_handling = 0 # Don't move if directed past the edge. + elif bh == "bounce": + self.b_handling = 1 # Bounce off edge (reflect coordinates). + elif bh == "trans": + self.b_handling = 2 # Translate to opposite side of the region. + else: + assert False, "Unsupported bounds handling" + + # Enumerated fitness (reward) function. + ff = kwargs.get("fit_func", "euc") + if ff == "euc": + self.fit_func = 0 # Single Euclidean (Pythagorean) distance value + elif ff == "disp": + self.fit_func = 1 # Tuple of x,y displacement values + elif ff == "rng": + self.fit_func = 2 # Range rings--the closer the ring, the lower the number + elif ff == "dir": + self.fit_func = 3 # direction--moving closer or farther away? + else: + assert False, "Unsupported fitness function" + + self.ring_size = kwargs.get("ring_size", 2) # Range ring size + self.bullseye = kwargs.get("bullseye", 10.0) # Intercept reward + self.teleport = kwargs.get("teleport", True) # Teleport after intercept + + # Initialize empty lists of relevant and distraction dots. + self.netDot = Dot(0, 0, self.decay) + self.dots = [] + self.herrings = [] + self.obs = np.zeros((self.h, self.w)) + + self.action_space = spaces.Discrete(self.choices) + + self.newPlot = True # One-time flag + + def step(self, action): + """ + Generates a grid for the current timestep. + See above for full description. 
+ + :param action: network's prediction of the dot movement + :return obs: observation of grid matrix of shape (h,w) + :return reward: precision of network's prediction in Euclidean distance + :return done: indicates termination of simulation + :return intercept: indicates a successful intercept this step + """ + + # Increment timestep + self.ts += 1 # Increment timestep + + # If the random rate is high enough, update movement direction. + if random.uniform(0, 1) <= self.randr: + self.dotDir = random.randint( + self.minch, self.choices - 1 + ) # five possible options + + # Initialize empty grid and populate as we update dots. + self.obs = np.zeros((self.h, self.w)) + + # Update network dot according to the network's action. + self.prevRow = self.netDot.row[0] + self.prevCol = self.netDot.col[0] + if action is not None: + self.movePoint(self.netDot, action) + + # self.obs = self.obs/(self.ndots + self.herrs) # normalize + reward, intercept = self.compute_reward() + + # Teleport network dot if intercept is successful. + if intercept and self.teleport: + bh1, bh2 = self.h // 5, 4 * self.h // 5 + bw1, bw2 = self.w // 5, 4 * self.w // 5 + r, c = random.randint(bh1, bh2), random.randint(bw1, bw2) + self.netDot = Dot(r, c, self.decay) + + # Redo grid observation. + self.obs = np.zeros((self.h, self.w)) + self.obs[r, c] = 1 + + # Move all relevant dots in the same direction. + for d in self.dots: + self.movePoint(d) + + # Move distraction dots with individually randomized motion. + for h in self.herrings: + self.movePoint(h, random.randint(self.minch, self.choices - 1)) + + reward = torch.Tensor(np.array(reward)) + done = self.timesteps <= self.ts + + return self.obs, reward, done, intercept + + def reset(self): + """ + Reset dots to initial positions, and reset RNG seed. 
+ """ + + self.ts = 0 # reset timesteps + + # Reset RNG + random.seed(self.seed) + + # provide default observation + self.obs = np.zeros((self.h, self.w)) + + # Set boundaries (so we don't spawn points on the edge. + # Not that there's a real problem with it, but it's boring. + bh1, bh2 = self.h // 5, 4 * self.h // 5 + bw1, bw2 = self.w // 5, 4 * self.w // 5 + + # Start dots in the middle. + # midr = self.h//2 <= we randomize this now. + # midc = self.w//2 + + # We know that the sum from n=0 to n=N of 1 - n/N = (N + 1)/2 + # Thus, computing the initial grid space for a dot, given the + # length of its tail N would be (self.decay + 1)/2. + # But... we also have to cap it at 1. So, who cares? + + # Reinitalize network dot placement. + r, c = random.randint(bh1, bh2), random.randint(bw1, bw2) + self.netDot = Dot(r, c, self.decay) + self.obs[r, c] = 1 + + # Reinitalize target dot placement with initial movement direction. + self.dots = [] + self.dotDir = random.randint(self.minch, self.choices - 1) + for d in range(self.ndots): + r, c = random.randint(bh1, bh2), random.randint(bw1, bw2) + self.dots.append(Dot(r, c, self.decay)) + self.obs[r, c] = 1 + + # Reinitalize red herring placement. + self.herrings = [] + for h in range(self.herrs): + r, c = random.randint(bh1, bh2), random.randint(bw1, bw2) + self.herrings.append(Dot(r, c, self.decay)) + self.obs[r, c] = 1 + + return self.obs + + def movePoint(self, d: Dot, dotDir: int = -1): + """ + Apply clockwise directional enumeration. + + :param dotDir: enumerated movement as described above. + :param/return r: current row => next row + :param/return c: current column => next column + """ + + # If not provided, use the known current direction. + targetDir = False # flag if we're using the target's direction. + if dotDir < 0: + dotDir = self.dotDir + targetDir = True + + r, c = d.row[0], d.col[0] + + """ Apply clockwise directional enumeration. """ + # 0 means stay, though we also won't go past the edge. 
+ if dotDir == 1: # up + r += self.speed + elif dotDir == 2: # right + c += self.speed + elif dotDir == 3: # down + r -= self.speed + elif dotDir == 4: # left + c -= self.speed + elif dotDir == 5: # up and right + r += self.speed + c += self.speed + elif dotDir == 6: # down and right + r -= self.speed + c += self.speed + elif dotDir == 7: # down and left + r -= self.speed + c -= self.speed + elif dotDir == 8: # up and left + r += self.speed + c -= self.speed + elif dotDir != 0: # Woops + assert False, "Unsupported dot direction" + + """ When a dot attempts to move past an edge... """ + # Stay put. + if self.b_handling == 0: + r = max(min(r, self.h - 1), 0) + c = max(min(c, self.w - 1), 0) + # direction stays the same. + + # Bounce: reflect its coordinates back into the region. + elif self.b_handling == 1: + if r < 0 or self.h <= r: + r = self.h - 1 - r % self.h # reflect row + if targetDir: + self.dotDir += ROW_CROSSING[dotDir] + + if c < 0 or self.w <= c: + c = self.w - 1 - c % self.w # reflect column + if targetDir: + self.dotDir += COL_CROSSING[dotDir] + + # Translate: the dot will continue in the same direction + # from the opposite side of the region. + elif self.b_handling == 2: + r = r % self.h # Mirror row + c = c % self.w # Mirror column + # direction stays the same. + + # Woops + else: + assert False, "Unsupported bounds handling" + + # Update the saved point in the Dot class. + # This also cycles the path history. + d.move(r, c) + + # Update the grid with this point and its decaying trail. + for t in range(self.decay): + self.obs[d.row[t], d.col[t]] = min( + self.obs[d.row[t], d.col[t]] + 1 - t / self.decay, 1 + ) + + def compute_reward(self): + """ + Computes reward according to the chosen fitness function. + Returns reward and flag indicating a successful intercept. 
+ """ + # Add bull's eye reward (if we're using it) + if ( + self.bullseye != 0 + and self.dots[0].row[0] == self.netDot.row[0] + and self.dots[0].col[0] == self.netDot.col[0] + ): + return self.bullseye, True + + reward = 0.0 + + # Euclidean distance + if self.fit_func == 0: + reward = -np.hypot( + self.dots[0].row[0] - self.netDot.row[0], + self.dots[0].col[0] - self.netDot.col[0], + ) + + # Displacement tensor + elif self.fit_func == 1: + reward = torch.Tensor( + [ + self.dots[0].row[0] - self.netDot.row[0], + self.dots[0].col[0] - self.netDot.col[0], + ] + ) + + # Range rings; default range ring size = 2 + elif self.fit_func == 2: + reward = ( + -np.hypot( + self.dots[0].row[0] - self.netDot.row[0], + self.dots[0].col[0] - self.netDot.col[0], + ) + // self.ring_size + ) + + # Directional + elif self.fit_func == 3: + rd1 = abs(self.dots[0].row[0] - self.prevRow) + rd2 = abs(self.dots[0].row[0] - self.netDot.row[0]) + cd1 = abs(self.dots[0].col[0] - self.prevCol) + cd2 = abs(self.dots[0].col[0] - self.netDot.col[0]) + + if rd2 < rd1: + reward += 1.0 # right row movement + elif rd1 < rd2: + reward -= 1.0 # wrong row movement + if cd2 < cd1: + reward += 1.0 # right col movement + elif cd1 < cd2: + reward -= 1.0 # wrong col movement + + # Woops + else: + assert False, "Unsupported fitness function" + + return reward, False + + def render(self): + """ + Display current state, either in ASCII or graphic plots. + """ + + # Double value of network dot only for visual aid in rendering. + temp = self.obs + temp[self.netDot.row, self.netDot.col] *= 2 + + # Write to file if requested. + if self.write2F: + f = open(self.filename, "ab") + np.savetxt(f, temp, delimiter=",") + f.close() + + # Print as pandas dataframe if requested. + if self.pandas: + print("Timestep:", self.ts) + print(pd.DataFrame(temp, dtype="uint32")) + + # Provide graphical rendering if requested. + if not self.mute: + # Otherwise, we'll render it as... I don't know yet. 
+ # print('Timestep:', self.ts) + # get current figure, clear it, and replot. + if self.newPlot: + self.fig = plt.gcf() + + plt.figure(self.fig.number) + plt.clf() + plt.imshow(temp, cmap="hot", interpolation="nearest") + + # Only display colorbar once. + if self.newPlot: + self.newPlot = False + # plt.ion() + plt.colorbar() + + # Pause so that that GUI can do its thing. + plt.pause(1e-8) + + def cycleOutFiles(self, newInt=-1): + """ + Increments numbered suffix on output file to start a new one. + """ + oldStr = "_" + str(self.fileCnt) + if 0 <= newInt: + self.fileCnt = newInt + else: + self.fileCnt += 1 + self.filename = self.filename.replace(oldStr, "_" + str(self.fileCnt)) + + def addFileSuffix(self, suffix): + """ + Adds suffix to output file (like "train" or "test"). + """ + self.filename = self.filename[:-5] + suffix + "_" + self.filename[-5:] + + def changeFileSuffix(self, sFrom, sTo): + """ + Adds suffix to output file (like "train" or "test"). + """ + self.filename = self.filename.replace(sFrom, sTo) + self.cycleOutFiles(newInt=0) # reset file count. + + +def driver(): + steps = 200 + dotSim = DotSimulator(200) + dotSim.reset() + + grids = np.zeros((steps, 28, 28)) + directions = np.zeros(steps) + done = False + for t in range(steps): + grids[t], reward, done, info = dotSim.step(0) + directions[t] = info["direction"] + + vals, cnts = np.unique(directions, return_counts=True) + print(vals, cnts / steps) + + +if __name__ == "__main__": + driver() diff --git a/bindsnet/learning/learning.py b/bindsnet/learning/learning.py index cd72334b..c8d693ae 100644 --- a/bindsnet/learning/learning.py +++ b/bindsnet/learning/learning.py @@ -36,7 +36,7 @@ def __init__( :param nu: Single or pair of learning rates for pre- and post-synaptic events. :param reduction: Method for reducing parameter updates along the batch dimension. - :param weight_decay: Constant multiple to decay weights by on each iteration. 
+ :param weight_decay: Coefficient controlling rate of decay of the weights each iteration. """ # Connection parameters. self.connection = connection @@ -106,7 +106,7 @@ def __init__( :param nu: Single or pair of learning rates for pre- and post-synaptic events. :param reduction: Method for reducing parameter updates along the batch dimension. - :param weight_decay: Constant multiple to decay weights by on each iteration. + :param weight_decay: Coefficient controlling rate of decay of the weights each iteration. """ super().__init__( connection=connection, @@ -148,7 +148,7 @@ def __init__( :param nu: Single or pair of learning rates for pre- and post-synaptic events. :param reduction: Method for reducing parameter updates along the batch dimension. - :param weight_decay: Constant multiple to decay weights by on each iteration. + :param weight_decay: Coefficient controlling rate of decay of the weights each iteration. """ super().__init__( connection=connection, @@ -264,7 +264,7 @@ def __init__( :param nu: Single or pair of learning rates for pre- and post-synaptic events. :param reduction: Method for reducing parameter updates along the batch dimension. - :param weight_decay: Constant multiple to decay weights by on each iteration. + :param weight_decay: Coefficient controlling rate of decay of the weights each iteration. """ super().__init__( connection=connection, @@ -402,7 +402,7 @@ def __init__( :param nu: Single or pair of learning rates for pre- and post-synaptic events. :param reduction: Method for reducing parameter updates along the batch dimension. - :param weight_decay: Constant multiple to decay weights by on each iteration. + :param weight_decay: Coefficient controlling rate of decay of the weights each iteration. """ super().__init__( connection=connection, @@ -508,7 +508,7 @@ def __init__( respectively. :param reduction: Method for reducing parameter updates along the minibatch dimension. 
- :param weight_decay: Constant multiple to decay weights by on each iteration. + :param weight_decay: Coefficient controlling rate of decay of the weights each iteration. Keyword arguments: @@ -702,7 +702,7 @@ def __init__( respectively. :param reduction: Method for reducing parameter updates along the minibatch dimension. - :param weight_decay: Constant multiple to decay weights by on each iteration. + :param weight_decay: Coefficient controlling rate of decay of the weights each iteration. Keyword arguments: @@ -906,7 +906,7 @@ def __init__( respectively. :param reduction: Method for reducing parameter updates along the minibatch dimension. - :param weight_decay: Constant multiple to decay weights by on each iteration. + :param weight_decay: Coefficient controlling rate of decay of the weights each iteration. Keyword arguments: diff --git a/docs/DotTraceSample.png b/docs/DotTraceSample.png new file mode 100644 index 00000000..b5b753f9 Binary files /dev/null and b/docs/DotTraceSample.png differ diff --git a/examples/README.md b/examples/README.md index 40dfd040..6b04989b 100644 --- a/examples/README.md +++ b/examples/README.md @@ -139,3 +139,33 @@ parameters|default|description --dt |1| SNN simulation time increment (in ms) --seed |0| initial random seed --tensorboard |True| whether to use **Tensorboard** or **Matplotlib** analyzer as output + +*** + +## Dot Tracing example +*/examples/dotTracing/* +dot_tracing.py trains a basic RNN on the Dot Simulator environment and demonstrates how to record reward and performance data (if desired) and plot spiking activity via monitors. + +See the environments directory for documentation on the Dot Simulator. 
+ +parameters|default|description +-|:-:|- +--steps |100| number of timesteps in an episode +--dim |28| square dimensions of grid (actual simulator can specify rows vs columns) +--dt |1| SNN simulation time increment (in ms) +--seed |0| initial random seed +--granularity |100| spike train granularity +--trn_eps |1000| training episodes +--tst_eps |100| test episodes +--decay |4| length of decaying trail behind target dot +--diag |False| enables diagonal movements +--randr |.15| determines rate of randomization of movement +--boundh |'bounce'| bounds handling mode +--fit_func |'dir'| fitness function, defaulted to directional +--allow_stay |False| disable option for targets to remain in place +--pandas |False| true = pandas Dataframe printout; false = heatmap +--mute |False| prohibit graphical rendering +--write |True| save observed grids to file +--fcycle |100| number of episodes per save file +--gpu |True| Utilize cuda +--herrs |0| number of distraction dots diff --git a/examples/dotTracing/dot_tracing.py b/examples/dotTracing/dot_tracing.py new file mode 100644 index 00000000..6447cdb8 --- /dev/null +++ b/examples/dotTracing/dot_tracing.py @@ -0,0 +1,311 @@ +from bindsnet.network import Network + +# from bindsnet.pipeline import EnvironmentPipeline +# from bindsnet.learning import MSTDP +from bindsnet.learning import MSTDPET +from bindsnet.learning import PostPre + +# from bindsnet.encoding import bernoulli +from bindsnet.encoding import poisson +from bindsnet.network.topology import Connection +from bindsnet.environment.dot_simulator import DotSimulator +from bindsnet.network.nodes import Input, LIFNodes + +# from bindsnet.pipeline.action import select_softmax +# from bindsnet.network.nodes import AbstractInput +from bindsnet.network.monitors import Monitor +from bindsnet.analysis.plotting import ( + plot_spikes, + # plot_performance +) + +import argparse +import numpy as np +import time +import torch + + +# Handle arguments for dot tracing params. 
def _str2bool(value):
    """
    Parses a boolean command-line value.

    argparse's ``type=bool`` is a trap: ``bool("False")`` is True, so any
    non-empty string would enable the flag.  This accepts the usual textual
    spellings instead, and leaves the (bool) defaults untouched.
    """
    if isinstance(value, bool):
        return value
    return value.strip().lower() in ("yes", "true", "t", "y", "1")


parser = argparse.ArgumentParser()
parser.add_argument("--steps", type=int, default=100)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--dim", type=int, default=28)
parser.add_argument("--granularity", type=int, default=100)
parser.add_argument("--neurons", type=int, default=100)
# dt is a time increment in ms; the original declared type=int with a float
# default, which would crash on fractional inputs like "--dt 0.5".
parser.add_argument("--dt", type=float, default=1.0)
parser.add_argument("--trn_eps", type=int, default=1000)
parser.add_argument("--tst_eps", type=int, default=100)
parser.add_argument("--decay", type=int, default=4)
parser.add_argument("--herrs", type=int, default=0)
parser.add_argument("--diag", type=_str2bool, default=False)
parser.add_argument("--randr", type=float, default=0.15)
parser.add_argument("--boundh", type=str, default="bounce")
parser.add_argument("--fit_func", type=str, default="dir")
parser.add_argument("--allow_stay", type=_str2bool, default=False)
parser.add_argument("--pandas", type=_str2bool, default=False)
parser.add_argument("--mute", type=_str2bool, default=False)
parser.add_argument("--write", type=_str2bool, default=True)
parser.add_argument("--fcycle", type=int, default=100)
parser.add_argument("--gpu", type=_str2bool, default=True)

args = parser.parse_args()

steps = args.steps  # timesteps in which the dot is moving
dim = args.dim  # 28x28 square
granularity = args.granularity  # granularity (or precision) of spike trains
neurons = args.neurons  # number of neurons in hidden layer
dt = args.dt  # delta time of network
trn_eps = args.trn_eps  # training episodes
tst_eps = args.tst_eps  # testing episodes
decay = args.decay  # length of decaying tail behind a dot
herrs = args.herrs  # distraction dots
diag = args.diag  # allows diagonal movement
randr = args.randr  # determines rate of randomization of movement
boundh = args.boundh  # bounds handling mode
fit_func = args.fit_func  # fitness function
allow_stay = args.allow_stay  # disable option for targets to remain in place.
+pandas = args.pandas # true = pandas DF printout; false = heatmap +mute = args.mute # prohibit graphical rendering +write = args.write # write observed grids to file. +fcycle = args.fcycle # number of episodes per save file +gpu = args.gpu # Utilize cuda + +if diag: + moveChoices = 9 +else: + moveChoices = 5 + +""" Set some globals """ +# Set processor type +if torch.cuda.is_available() and gpu: + DEVICE = torch.device("cuda") +else: + DEVICE = torch.device("cpu") + +# Set neural network layer names +LAYER1 = "Input" +LAYER2 = "Hidden" +LAYER3 = "Output" + +# file path for recording grid observations, rewards, and performance. +OUT_FILE_PATH = "out/" + + +def genFileName(ftype, suffix=""): + """ + Generates output file names for rewards and performance + """ + # Grab system time and trim off extra large parts of the number. + sysTime = time.time() + sysTime = int(1e10 * (sysTime - 1e6 * (sysTime // 1e6))) + + # Create filename if one isn't provided. + return OUT_FILE_PATH + ftype + "_s" + str(sysTime) + "_" + suffix + ".csv" + + +def runSimulator(net, env, spikes, episodes, gran=100, rfname="", pfname=""): + + steps = env.timesteps + dt = net.dt + spike_ims, spike_axes = None, None + + # For each episode... + for ep in range(episodes): + # Reset variables for new episode. + total_reward = 0 + rewards = np.zeros(steps) + intercepts = 0 + step = 0 + net.reset_state_variables() + env.reset() + done = False + env.render() + clock = time.time() + + # Initialize action tensor, network output monitor, and spike train record. + action = torch.randint(low=0, high=env.action_space.n, size=(1,))[0] + spike_record = torch.zeros( + (steps, int(gran / dt), env.action_space.n), device=DEVICE + ) + # perf_ax = None + + # Run through episode. 
+ while not done: + + step += 1 + obs, reward, done, intercept = env.step(action) + obs = torch.Tensor(obs).to(DEVICE) + reward = reward.to(DEVICE) + + # Determine the action probabilities + probabilities = torch.softmax( + torch.sum(spike_record[step - 1 % steps], dim=0), dim=0 + ) + action = torch.multinomial(probabilities, num_samples=1).item() + + # Place the observations into the inputs. + obs = obs.unsqueeze(0) + inputs = {LAYER1: poisson(obs * 5e2, gran, dt, device=DEVICE)} + if DEVICE == "cuda": + inputs = {k: v.cuda() for k, v in inputs.items()} + + # Run the network on the spike train-encoded inputs. + net.run(inputs=inputs, time=gran, reward=reward) + spike_record[step % steps] = spikes[LAYER3].get("s").squeeze() + rewards[step - 1] = reward.item() + + # record successful intercept + if intercept: + intercepts += 1 + + if done: + # Update network with cumulative reward + if net.reward_fn is not None: + net.reward_fn.update(accumulated_reward=total_reward, steps=step) + + # Save rewards thus far to file + if rfname != "": + f = open(rfname, "ab") + np.savetxt(f, rewards, delimiter=",", fmt="%.6f") + f.close() + + spikes_ = {layer: spikes[layer].get("s").view(gran, -1) for layer in spikes} + spike_ims, spike_axes = plot_spikes(spikes_, ims=spike_ims, axes=spike_axes) + # perf_ax = plot_performance(reward, x_scale=10, ax=perf_ax) + + env.render() + total_reward += reward + + if step % 10 == 0: + print( + f"Iteration: {step} (Time: {time.time() - clock:.4f}); reward: {reward}" + ) + clock = time.time() + + print(f"Episode {ep} total reward:{total_reward}") + # Save intcercepts thus far to file + if pfname != "": + f = open(pfname, "a+") + if 0 < ep: + f.write("," + str(intercepts)) + else: + f.write(str(intercepts)) + f.close() + + # Cycle output files every 10000 iterations + if ep % fcycle == 0: + env.cycleOutFiles() + + +def main(): + + # Build network. 
    network = Network(dt=dt)

    # Input Layer: one neuron per grid cell.
    inpt = Input(n=dim * dim, shape=[1, 1, 1, dim, dim], traces=True)

    # Hidden Layer
    middle = LIFNodes(n=neurons, traces=True)

    # Output Layer: one neuron per movement choice.
    out = LIFNodes(n=moveChoices, refrac=0, traces=True)

    # Connections from input layer to hidden layer
    inpt_middle = Connection(source=inpt, target=middle, wmin=0, wmax=1)

    # Connections from hidden layer to output layer
    middle_out = Connection(
        source=middle,
        target=out,
        wmin=0,  # minimum weight value
        wmax=1,  # maximum weight value
        update_rule=MSTDPET,  # learning rule
        nu=1e-1,  # learning rate
        norm=0.5 * middle.n,  # normalization
    )

    # Recurrent connection, retaining data within the hidden layer
    recurrent = Connection(
        source=middle,
        target=middle,
        wmin=0,  # minimum weight value
        wmax=1,  # maximum weight value
        update_rule=PostPre,  # learning rule
        nu=1e-1,  # learning rate
        norm=5e-3 * middle.n,  # normalization
    )

    # Add all layers and connections to the network.
    network.add_layer(inpt, name=LAYER1)
    network.add_layer(middle, name=LAYER2)
    network.add_layer(out, name=LAYER3)
    network.add_connection(inpt_middle, source=LAYER1, target=LAYER2)
    network.add_connection(middle_out, source=LAYER2, target=LAYER3)
    network.add_connection(recurrent, source=LAYER2, target=LAYER2)
    network.to(DEVICE)

    # Add a spike monitor ("s") per layer so runSimulator can read outputs.
    # network.add_monitor(Monitor(network.layers["Hidden"], ["s"], time=granularity), "Hidden")
    # network.add_monitor(Monitor(network.layers["Output"], ["s"], time=granularity), "Output")
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(
            network.layers[layer],
            state_vars=["s"],
            time=int(granularity / dt),
            device=DEVICE,
        )
        network.add_monitor(spikes[layer], name=layer)

    # Load the Dot Simulation environment.
    environment = DotSimulator(
        steps,
        decay=decay,
        herrs=herrs,
        diag=diag,
        randr=randr,
        write=write,
        mute=mute,
        bound_hand=boundh,
        fit_func=fit_func,
        allow_stay=allow_stay,
        pandas=pandas,
        fpath=OUT_FILE_PATH,
    )
    environment.reset()

    # Training phase: learning rules active, output files tagged "train".
    print("Training: ")
    rewFile = genFileName("rew", "train")
    perfFile = genFileName("perf", "train")
    environment.addFileSuffix("train")
    runSimulator(
        network,
        environment,
        spikes,
        episodes=trn_eps,
        gran=granularity,
        rfname=rewFile,
        pfname=perfFile,
    )

    # Freeze learning
    # NOTE(review): assigning the attribute directly may not propagate to
    # child layers/connections the way Network.train(False) would — confirm
    # against the bindsnet API version in use.
    network.learning = False

    # Test phase: same loop, output files retagged "test".
    print("Testing: ")
    rewFile = genFileName("rew", "test")
    perfFile = genFileName("perf", "test")
    environment.changeFileSuffix("train", "test")
    runSimulator(
        network,
        environment,
        spikes,
        episodes=tst_eps,
        gran=granularity,
        rfname=rewFile,
        pfname=perfFile,
    )


if __name__ == "__main__":
    main()