import collections
import collections.abc
import itertools
import random
class GridSearch:
    """
    Simple grid search that calls a generic function once per parameter combination.

    :param function: (callable):
        function to perform the grid search on
    :param grid_params: (dict):
        dictionary mapping parameter names to their candidate values, e.g.::

            {'input_a': ['dog', 'cat', 'stuff'],
             'input_b': [3, 10, 22]}

        A value that is a non-string iterable is treated as a collection of
        candidates; any other value (including a string) is treated as one
        fixed value. The dictionary is never modified by the search.
    :param search_behavior: (string):
        how to perform the search. Options are ``'exhaustive'`` and
        ``'sampled_x.x'`` (where ``x.x`` is the sample threshold).

        ``exhaustive`` - try every combination, in the order the keys appear in
        the dictionary (the last key gets all of its values searched first).

        ``sampled`` - one random number in ``[0.0, 1.0)`` is drawn per
        combination; the combination is executed unless the draw is *above*
        the threshold. A bare ``'sampled'`` keeps the threshold at 1.0, so
        every combination is executed.
    :param args_as_dict: (bool):
        There are two ways to pass parameters into the function:

        1. ``True`` - pass a single dictionary ``{name: value}`` as the only
           positional argument. Each call receives its own copy.
        2. ``False`` - pass each key of ``grid_params`` as a keyword argument.

        Defaults to ``True``.
    """

    def __init__(self, function, grid_params, search_behavior='exhaustive', args_as_dict=True):
        self.func = function
        self.args = grid_params
        self.args_as_dict = args_as_dict
        # With the default threshold of 1.0 a sampled search never skips anything,
        # because random.random() is always strictly below 1.0.
        self.sampled_thresh = 1.0
        if 'sampled_' in search_behavior:
            # 'sampled_0.25' -> behavior 'sampled', threshold 0.25
            parts = search_behavior.split('_')
            self.behavior = parts[0]
            self.sampled_thresh = float(parts[1])
        else:
            self.behavior = search_behavior

    @staticmethod
    def _choices(values):
        """
        Return the list of candidate values for a single grid entry.

        :param values:
            the raw value stored under one key of the grid
        :return:
            ``list(values)`` for a non-string iterable, otherwise ``[values]``
        """
        # Strings are iterable too, but are meant as one value, not as characters.
        if isinstance(values, collections.abc.Iterable) and not isinstance(values, str):
            # Materialise so one-shot iterables can be reused for every outer combination.
            return list(values)
        return [values]

    def _invoke(self, input_args):
        """
        Call the wrapped function for one combination, honouring the sampling threshold.

        :param input_args:
            dictionary mapping every parameter name to the value chosen for this call
        """
        # Exactly one random draw per combination, and only in sampled mode.
        if self.behavior == 'sampled' and random.random() > self.sampled_thresh:
            return
        if self.args_as_dict:
            # Pass a copy: input_args is reused for the next combination, so handing
            # out the same object would make every stored result alias the last one.
            self.func(dict(input_args))
        else:
            self.func(**input_args)

    def _execute(self, input_args, available_args):
        """
        Execute the function for every combination of the still-unassigned parameters.

        :param input_args:
            dictionary into which the chosen arguments are collected; it is updated
            in place and may already hold values fixed by the caller
        :param available_args:
            dictionary mapping each remaining parameter name to its candidate
            value(s); it is only read, never modified. An empty dictionary results
            in a single call using ``input_args`` as it stands.
        """
        names = list(available_args)
        candidates = [self._choices(available_args[name]) for name in names]
        # itertools.product varies the right-most factor fastest, i.e. the last key
        # gets all of its values searched first.
        for combination in itertools.product(*candidates):
            input_args.update(zip(names, combination))
            self._invoke(input_args)

    def run(self):
        """
        Runs GridSearch by iterating over the options given at construction time.

        Can be called repeatedly; every call walks the same combinations in the same order.

        :return: None
        """
        input_args = {}
        self._execute(input_args, self.args)