Skip to content
Snippets Groups Projects
Commit 6116eccd authored by Joao's avatar Joao
Browse files

Integrate weigths and biases

parent 31ae57c0
Branches
No related tags found
No related merge requests found
......@@ -3,5 +3,6 @@ dist/
*.swp
__pycache__/
logs/
wandb/
*.egg-info/
......@@ -54,6 +54,11 @@ for env in envs:
**d,
some_default_param=some_default_param,
wandb_enabled=True,
wandb_entity='jaccarvalho',
wandb_project='test_experiment_launcher',
wandb_group=f'test_group-{env}-{a}={boolean_param}',
# A subdirectory will be created for parameters with a trailing double underscore.
env__=env,
a__=a,
......
import argparse
import os
import wandb
from experiment_launcher import run_experiment
from experiment_launcher.decorators import single_experiment
from experiment_launcher.decorators import single_experiment, single_experiment_wandb
from experiment_launcher.launcher import add_launcher_base_args, get_experiment_default_params
# This decorator is not mandatory.
# It creates results_dir as results_dir/seed, and saves the experiment arguments into a file.
@single_experiment
@single_experiment_wandb
def experiment(
#######################################
env: str = 'env', # You need to specify the argument type if you use the automatic parser.
env_param: str = 'aa',
a: int = 1,
boolean_param: bool = True,
some_default_param: str = 'b',
seed: int = 0, # This argument is mandatory
results_dir: str = '/tmp' # This argument is mandatory
#######################################
# WandB
wandb_enabled: bool = False,
wandb_entity: str = 'jaccarvalho',
wandb_project: str = 'test_experiment_launcher',
wandb_group: str = 'test_group',
#######################################
# MANDATORY
seed: int = 0,
results_dir: str = 'logs'
):
# EXPERIMENT
filename = os.path.join(results_dir, 'log_' + str(seed) + '.txt')
......@@ -29,6 +41,8 @@ def experiment(
file.write('Some logs in a log file.\n')
file.write(out_str)
wandb.log({'seed': seed})
# You can specify your own parser, or use the experiment_launcher parser.
def parse_args():
......
import datetime
import os
from functools import wraps
from functools import wraps, partial
from experiment_launcher.utils import save_args
from experiment_launcher.utils import save_args, fix_random_seed, start_wandb
def single_experiment(exp_func):
def pseudo_single_experiment(exp_func, save_args_yaml=False, use_wandb=False):
@wraps(exp_func)
def wrapper(*args, **kwargs):
# Make results directory
......@@ -12,10 +14,30 @@ def single_experiment(exp_func):
results_dir = os.path.join(kwargs['results_dir'], str(kwargs['seed']))
os.makedirs(results_dir, exist_ok=True)
kwargs['results_dir'] = results_dir
# Save arguments
save_args(results_dir, kwargs, git_repo_path='./')
save_args(results_dir, kwargs, git_repo_path='./', save_args_yaml=save_args_yaml)
# Fix seed
fix_random_seed(kwargs['seed'])
# Start WandB
wandb_run = None
if use_wandb:
wandb_run = start_wandb(**kwargs)
# Run the experiment
exp_func(*args, **kwargs)
# Clean up
if use_wandb:
wandb_run.finish()
return wrapper
single_experiment = partial(pseudo_single_experiment)
single_experiment_yaml = partial(pseudo_single_experiment, save_args_yaml=True)
single_experiment_wandb = partial(pseudo_single_experiment, save_args_yaml=True, use_wandb=True)
import datetime
import json
import os
import socket
import wandb
import yaml
import git
from git import InvalidGitRepositoryError
import random
try:
import numpy as np
import torch
except ImportError:
pass
def save_args(results_dir, args, git_repo_path=None, seed=None):
def save_args(results_dir, args, git_repo_path=None, seed=None, save_args_yaml=False):
try:
repo = git.Repo(git_repo_path, search_parent_directories=True)
args['git_hash'] = repo.head.object.hexsha
......@@ -16,14 +26,14 @@ def save_args(results_dir, args, git_repo_path=None, seed=None):
args['git_hash'] = ''
args['git_url'] = ''
filename = 'args.json' if seed is None else f'args-{seed}.json'
if save_args_yaml:
filename = 'args.yaml' if seed is None else f'args-{seed}.yaml'
with open(os.path.join(results_dir, filename), 'w') as f:
yaml.dump(args, f, Dumper=yaml.Dumper)
# filename = 'args.json' if seed is None else f'args-{seed}.json'
# with open(os.path.join(results_dir, filename), 'w') as f:
# json.dump(args, f, indent=2)
else:
filename = 'args.json' if seed is None else f'args-{seed}.json'
with open(os.path.join(results_dir, filename), 'w') as f:
json.dump(args, f, indent=2)
del args['git_hash']
del args['git_url']
......@@ -32,3 +42,44 @@ def save_args(results_dir, args, git_repo_path=None, seed=None):
def bool_local_cluster():
hostname = socket.gethostname()
return False if hostname == 'mn01' or 'logc' in hostname else True
def start_wandb(
wandb_enabled=False,
wandb_entity='experiment_launcher',
wandb_project='test_experiment_launcher',
wandb_group=None,
wandb_run_name=None,
**kwargs
):
if not wandb_enabled:
return wandb.init(mode="disabled")
init = {
"entity": wandb_entity,
"project": wandb_project,
"group": wandb_group,
"name": wandb_run_name,
"reinit": True,
"notes": datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
}
return wandb.init(**init)
def fix_random_seed(seed):
random.seed(seed)
try:
np.random.seed(seed)
except NameError:
pass
try:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# https://pytorch.org/docs/stable/notes/randomness.html#cuda-convolution-benchmarking
# torch.backends.cudnn.benchmark = False
except NameError:
pass
......@@ -2,3 +2,4 @@ joblib
gitpython
numpy
pyyaml
wandb
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment