Source code for params_proto.v2.hyper

import itertools
from collections import defaultdict, namedtuple
from contextlib import contextmanager
from typing import ContextManager, Dict, Iterable, TypeVar, Union

from params_proto.v2.proto import Meta, ParamsProto, Proto


def dot_join(*keys):
    """Join the truthy ``keys`` with ``"."``; return ``None`` if none are truthy.

    Every falsy key is dropped -- ``None`` *and* ``""`` -- so
    ``dot_join(None, "lr")`` and ``dot_join("", "lr")`` both give ``"lr"``,
    while ``dot_join(None, None)`` gives ``None``.
    """
    parts = [k for k in keys if k]
    return ".".join(parts) if parts else None
# A (key, value) pair. ``Sweep.set_param`` records ``Item(dotted_key_or_None, values)``
# and ``key_items`` builds ``Item(dotted_key, single_value)`` entries from those.
Item = namedtuple("Item", "key value")
def key_items(d, prefix=None):
    """Expand each recorded ``(key, values)`` pair into a list with one entry per value.

    For every ``(k, vs)`` in ``d`` one list is yielded. When ``dot_join(prefix, k)``
    gives a key, each value ``v`` becomes ``Item(key, v)``. When it gives ``None``
    (no prefix and no key), the values are passed through unchanged.
    """
    for name, values in d:
        full_key = dot_join(prefix, name)
        if full_key:
            yield [Item(full_key, value) for value in values]
        else:
            yield list(values)
def flatten_items(row) -> Iterable[Item]:
    """Recursively flatten a (nested) row, yielding every ``Item`` it contains.

    ``Item`` instances are yielded as-is; they are tuples, so they must be checked
    before the generic iterable case. Other iterables are descended into. Strings
    are treated as atoms, because iterating a one-character string yields that same
    string and would recurse forever. Any other value is yielded unchanged.
    """
    if isinstance(row, Item):
        yield row
    elif isinstance(row, Iterable) and not isinstance(row, str):
        for item in row:
            yield from flatten_items(item)
    else:
        yield row
# Generic type variable used in Sweep's ``ContextManager[T]`` annotations.
# The name given to TypeVar must match the variable it is assigned to (PEP 484).
T = TypeVar("T")
class Sweep:
    """Declarative hyper-parameter sweep over one or more ParamsProto objects.

    Inside the ``product`` / ``zip`` / ``chain`` / ``set`` contexts (or ``with sweep:``),
    assignments to the registered protos are intercepted by hooks and recorded on
    ``self.stack`` instead of being applied. Iterating the sweep resets each proto to
    its original values, applies one recorded configuration, and yields that
    configuration as a flat ``{dotted_key: value}`` dict.
    """

    _d = None  # cache for the column-wise ``__dict__`` view
    __original = None  # cache for ``original``
    __noot = None
    __each_fn = None  # callback registered through ``each``

    def each(self, fn):
        """Register ``fn`` to derive extra parameters for every configuration.

        While iterating, ``fn`` is called with the protos inside a nested ``Sweep``
        context, and every configuration it records is merged into the outer one.

        Returns:
            ``self``, so the call can be chained.
        """
        self.__each_fn = fn
        return self

    # noinspection PyProtectedMember
    def __init__(self, *protos: Meta):
        """Register the protos to sweep over.

        Each proto is keyed by its ``_prefix``, or by the proto object itself when
        the prefix is missing/empty.
        """
        # the ParamsProto is updatable via proto._update(dot_dict)
        # use object itself as key if _prefix is missing
        self.root: Dict[str, ParamsProto] = {p._prefix or p: p for p in protos}
        # one frame per open context; the bottom frame holds the finished sweep.
        self.stack = [[]]

    def __len__(self):
        """Number of configurations. Runs (and therefore applies) the whole sweep."""
        return len(self.list)

    def __getitem__(self, item: Union[slice, int, float]):
        """Select configurations by index or slice, applying each one before yielding it.

        This is a generator: ``sweep[i]`` returns an iterator yielding at most one
        configuration, and ``sweep[a:b:c]`` returns one yielding the selected
        configurations. Non-negative indices are consumed lazily. A negative
        start/stop/step or index materializes the whole sweep first.

        Raises:
            NotImplementedError: for index types other than ``slice`` and ``int``
                (raised when the generator is first advanced).
        """
        if isinstance(item, slice):
            assert item.step != 0, "step can not be zero."
            if any(x is not None and x < 0 for x in (item.start, item.stop, item.step)):
                for override in self.list[item]:
                    self._apply(override)
                    yield override
                return
            # ``__iter__`` already applies each configuration before yielding it.
            # ``islice`` also correctly yields nothing when stop <= start (``sweep[:0]``).
            yield from itertools.islice(self, item.start, item.stop, item.step)
        elif isinstance(item, int):
            if item < 0:
                override = self.list[item]
            else:
                override = next(itertools.islice(self, item, item + 1), None)
                if override is None:
                    return  # index past the end: yield nothing
            self._apply(override)
            yield override
        else:
            raise NotImplementedError(f"slicing is not implemented for {item}")

    def items(self):
        """Iterate ``(index, configuration)`` pairs."""
        return enumerate(self)

    @property
    def list(self):
        """returns self as a list. Currently not idempotent. Might become idempotent in the future."""
        return [*iter(self)]

    @property
    def dataframe(self):
        """The sweep as a pandas DataFrame, one row per configuration."""
        import pandas as pd

        return pd.DataFrame(self.list)

    @property
    def __dict__(self):
        """Column-wise view ``{key: [value of each configuration]}``, cached once non-empty."""
        if self._d:
            return self._d
        self._d = defaultdict(list)
        for config in self:
            for k, v in config.items():
                self._d[k].append(v)
        return self._d

    @property
    def noot(self):
        """A deep copy of ``root``.

        The protos appear to be classes (``original`` calls ``type.__getattribute__`` on
        them), which ``deepcopy`` returns unchanged, so the values are the same protos.
        """
        from copy import deepcopy

        return deepcopy(self.root)

    @property
    def snack(self):
        """A deep copy of ``stack``, so iterating never consumes or mutates the recorded values."""
        from copy import deepcopy

        return deepcopy(self.stack)

    def _make_hooks(self, prefix):
        """Build the ``(set_hook, get_hook)`` pair installed on one proto by ``__enter__``."""
        # A separate scope per proto: closures bind variables, not values, so hooks
        # defined directly in the ``__enter__`` loop would all share the last ``data``.
        data = {}

        def set_hook(_, k, v, p=prefix):
            # note: we wrap this value in Proto, so that we can distinguish
            # between true None vs a None value set by the user.
            data[k] = Proto(v)
            return self.set_param(k, [v], prefix=p)

        def get_hook(_, k, p=prefix):
            # note: This is used internally by the ParamsProto class to decide on
            # overrides without key-string filtering, which is prone to errors.
            return data.get(k, None)

        return set_hook, get_hook

    def __enter__(self):
        """Open a frame and hook every proto; each assignment records a single value."""
        self.stack.append([])
        for proto in self.root.values():
            proto._add_hooks(*self._make_hooks(proto._prefix))
        return self

    def __exit__(self, *args):
        """Unhook the protos and record the Cartesian product of this frame in the parent frame."""
        for proto in self.root.values():
            proto._pop_hooks()
        frame = self.stack.pop(-1)
        # Materialized: ``snack`` deep-copies the stack, and itertools iterators can no
        # longer be deep-copied (support removed in Python 3.14).
        self.set_param(None, list(itertools.product(*key_items(frame))))

    @property
    def original(self):
        """Per-proto snapshot of the attribute values restored before each configuration.

        Computed once from ``vars(proto)``, in ``root`` order. Attributes whose value
        (looked up via ``type.__getattribute__``) has a truthy ``accumulant`` flag are
        skipped.
        """
        if self.__original is None:
            snapshots = []
            for proto in self.noot.values():
                # noinspection PyCallByClass
                def no_reset(k, proto=proto):
                    return getattr(type.__getattribute__(proto, k), "accumulant", False)

                snapshots.append({k: v for k, v in vars(proto).items() if not no_reset(k)})
            self.__original = snapshots
        return self.__original

    def _apply(self, override):
        """Reset every proto to its ``original`` snapshot, then apply ``override`` on top.

        Prefixed protos receive the whole ``override``. Protos without a ``_prefix``
        only receive the keys that appear in their original snapshot.
        """
        for org, proto in zip(self.original, self.noot.values(), strict=False):
            proto._update(**org)
            # only apply those key-value pairs that appear in the original.
            proto._update(
                override if proto._prefix else {k: v for k, v in override.items() if k in org}
            )

    def __iter__(self):
        """Apply each recorded configuration to the protos and yield it as a flat dict.

        With an ``each`` callback, every configuration it records for the current
        protos is merged into (and overrides keys of) the yielded dict.
        """
        for row in itertools.chain(*[it.value for it in self.snack[-1]]):
            override = dict(flatten_items(row))
            self._apply(override)
            if callable(self.__each_fn):
                with Sweep(*self.noot.values()) as sweep:
                    self.__each_fn(*self.noot.values())
                for deps in sweep:
                    yield {k: v for k, v in itertools.chain(override.items(), deps.items())}
            else:
                yield override

    def set_param(self, name, params, prefix=None):
        """Record ``params`` (the values to sweep) for ``prefix.name`` in the current frame."""
        item = Item(dot_join(prefix, name), params)
        self.stack[-1].append(item)

    @contextmanager
    def _collect(self, combine):
        """Record assignments made inside the block, then store ``combine(frame)`` in the parent frame."""
        self.stack.append([])
        try:
            for proto in self.root.values():
                prefix = proto._prefix
                proto._add_hooks(lambda _, *args, p=prefix: self.set_param(*args, prefix=p))
            yield self
        finally:
            for proto in self.root.values():
                proto._pop_hooks()
            frame = self.stack.pop(-1)
            self.set_param(None, combine(frame))

    @property
    def product(self) -> ContextManager[None]:
        """Sweep over the Cartesian product of the assignments made inside the block."""
        # lists, not lazy iterators: see the note in ``__exit__``.
        return self._collect(lambda frame: list(itertools.product(*key_items(frame))))

    @property
    def zip(self) -> ContextManager[T]:
        """Pair the i-th values of every assignment made inside the block (stops at the shortest)."""
        return self._collect(lambda frame: list(zip(*key_items(frame), strict=False)))

    @property
    @contextmanager
    def set(self) -> ContextManager[T]:
        """Context equivalent to ``with sweep:``; each assignment records a single value."""
        try:
            yield self.__enter__()
        finally:
            self.__exit__()

    @property
    def chain(self) -> ContextManager[T]:
        """Concatenate, one after another, the sub-sweeps recorded inside the block."""
        return self._collect(
            lambda frame: list(itertools.chain.from_iterable(value for _, value in frame))
        )

    def save(self, filename="sweep.jsonl", overwrite=True, verbose=True):
        """Write every configuration to ``filename`` as JSON Lines.

        Args:
            filename: output path.
            overwrite: truncate the file (mode ``"w"``) when True, otherwise append (``"a+"``).
            verbose: print a summary with a ``file://`` link to the saved file.
        """
        import json

        from termcolor import colored as c

        # todo: connect to ml-logger to setup managed sweep
        # materialize once: every access to ``self.list`` re-runs the whole sweep.
        configs = self.list
        with open(filename, "w" if overwrite else "a+") as f:
            for item in configs:
                f.write(json.dumps(item) + "\n")
        if verbose:
            import os
            from urllib import parse

            print(
                c("saved", "blue"),
                c(len(configs), "green"),
                c("items to", "blue"),
                filename,
                ".",
                # this is to show file path in console.
                "file://" + parse.quote(os.path.realpath(filename)),
            )

    @staticmethod
    def log(deps, filename):
        """append deps object to a JSONL log file, used as a helper function"""
        import json

        with open(filename, "a+") as f:
            f.write(json.dumps(deps) + "\n")

    @staticmethod
    def read(filename):
        """Read a JSONL log file into a list of objects, used as a helper function.

        Blank lines and lines starting with ``//`` are skipped.
        """
        import json

        sweep = []
        with open(filename, "r") as f:
            for raw in f:
                line = raw.strip()
                if not line or line.startswith("//"):
                    continue
                sweep.append(json.loads(line))
        return sweep

    file = None  # the source last passed to ``load``

    def load(self, file="sweep.jsonl", strict=True, silent=False):
        """
        Loading sweep state from a jsonl file:

        Note: **Important Caveat**
            When multiple prefix-free ParamsProto objects are present, we sweep through all
            of the proto objects and set the attribute on the first proto that has the key.

            This first-attr approach works because the ParamsProto object also generates
            argparse parameters, which means repetitive arguments are not possible.

            However this would fail in cases where attributes are dynamically added to an
            argument object. The `sweep.jsonl` file loses this type of information, therefore
            there is no way to recover this type of attributes. So the user should try to
            either use PrefixProto, or explicitly define the attributes.

        Usage Pattern 1: Loading from a file::

            sweep = Sweep(Args, RUN).load('sweep.jsonl')
            for i, deps in enumerate(sweep):
                assert RUN.job_id == i + 1, "the job_id in that sweep.json should be 1-based."

        Usage Pattern 2: Loading from a sweep list object or a pandas DataFrame::

            sweep_list = Sweep.read('sweep.jsonl')
            sweep = Sweep(Args, RUN).load(sweep_list)
            for i, deps in enumerate(sweep):
                assert RUN.job_id == i + 1, "the job_id in that sweep.json should be 1-based."

        Args:
            file: a path to a JSONL file, a list of dicts, or a pandas DataFrame.
            strict: raise ``KeyError`` for columns that match no proto attribute.
            silent: when not strict, suppress the warning printed for such columns.

        Returns:
            ``self``, so the call can be chained.
        """
        import pandas as pd
        from termcolor import colored

        self.file = file
        if isinstance(file, str):
            file = self.read(file)

        if isinstance(file, list):
            df = pd.DataFrame(file)
        elif isinstance(file, pd.DataFrame):
            df = file
        else:
            raise TypeError(f"{type(file)} is not supported")

        # every column is zipped: row i of the table becomes configuration i.
        with self.zip:
            for full_key in df:
                prefix, *keys = full_key.split(".")
                # an undotted column (e.g. "RUN") has no attribute part to look up on a prefixed proto.
                if prefix in self.root and keys:
                    proto = self.root[prefix]
                    if not hasattr(proto, keys[0]):
                        if strict:
                            raise KeyError(f'{proto} does not contain the key "{prefix}.{keys[0]}"')
                        if not silent:
                            print(
                                colored(f'{proto} does not contain the key "', "red")
                                + colored(f"{full_key}", "green")
                                + colored('" ', "red")
                            )
                    setattr(proto, ".".join(keys), df[full_key].values.tolist())
                else:
                    # prefix-free protos are keyed by the proto object itself, not a string.
                    for k, proto in self.root.items():
                        if isinstance(k, str):
                            continue
                        if hasattr(proto, prefix):
                            setattr(proto, full_key, df[full_key].values.tolist())
                            break
                    else:
                        if strict:
                            raise KeyError(
                                f'The key "{full_key}" does not appear in any of the Arguments'
                            )
                        if not silent:
                            print(
                                colored('The key "', "red")
                                + colored(f"{full_key}", "green")
                                + colored('" ', "red")
                                + colored("does not appear in any of the Arguments", "red")
                            )
        return self