#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Jun 17 13:24:03 2022
@author: chamaeleontis
Contains utility functions and decorators to use in other script.
"""
#%% Importing libraries
import pickle
import time
import os
from functools import wraps
import inspect
import logging
from colorlog import ColoredFormatter
from functools import wraps
import inspect
import logging
logger = logging.getLogger(__name__)
[docs]
def addLoggingLevel(levelName, levelNum, methodName=None):
"""
Comprehensively adds a new logging level to the `logging` module and the
currently configured logging class.
`levelName` becomes an attribute of the `logging` module with the value
`levelNum`. `methodName` becomes a convenience method for both `logging`
itself and the class returned by `logging.getLoggerClass()` (usually just
`logging.Logger`). If `methodName` is not specified, `levelName.lower()` is
used.
To avoid accidental clobberings of existing attributes, this method will
raise an `AttributeError` if the level name is already an attribute of the
`logging` module or if the method name is already present
Example
-------
>>> addLoggingLevel('TRACE', logging.DEBUG - 5)
>>> logging.getLogger(__name__).setLevel("TRACE")
>>> logging.getLogger(__name__).trace('that worked')
>>> logging.trace('so did this')
>>> logging.TRACE
"""
if not methodName:
methodName = levelName.lower()
if hasattr(logging, levelName):
pass
# raise AttributeError('{} already defined in logging module'.format(levelName))
if hasattr(logging, methodName):
pass
# raise AttributeError('{} already defined in logging module'.format(methodName))
if hasattr(logging.getLoggerClass(), methodName):
pass
# raise AttributeError('{} already defined in logger class'.format(methodName))
# This method was inspired by the answers to Stack Overflow post
# http://stackoverflow.com/q/2183233/2988730, especially
# http://stackoverflow.com/a/13638084/2988730
def logForLevel(self, message, *args, **kwargs):
if self.isEnabledFor(levelNum):
self._log(levelNum, message, args, **kwargs)
def logToRoot(message, *args, **kwargs):
logging.log(levelNum, message, *args, **kwargs)
logging.addLevelName(levelNum, levelName)
setattr(logging, levelName, levelNum)
setattr(logging.getLoggerClass(), methodName, logForLevel)
setattr(logging, methodName, logToRoot)
[docs]
def logger_default(logger_update: logging.Logger) -> logging.Logger:
"""
Updates the Logger to include print level and different set of colors for each level.
Parameters
----------
logger_update : logging.Logger
Logger to update
Returns
-------
logging.Logger
Outputed logger
"""
addLoggingLevel('PRINT', 25, methodName=None)
LOG_LEVEL = logging.INFO
# LOG_LEVEL = logging.DEBUG
LOGFORMAT = " %(log_color)s%(levelname)-8s%(reset)s | %(log_color)s%(message)s%(reset)s"
from colorlog import ColoredFormatter
logging.root.setLevel(LOG_LEVEL)
formatter = ColoredFormatter(LOGFORMAT)
formatter.log_colors['WARNING'] = "cyan"
formatter.log_colors['PRINT'] = "yellow"
stream = logging.StreamHandler()
stream.setLevel(LOG_LEVEL)
stream.setFormatter(formatter)
logger_update.setLevel(LOG_LEVEL)
logger_update.addHandler(stream)
return logger_update
#%% time_function
[docs]
def time_function(func):
'''
Decorator function for timing the function
Parameters
----------
func : function
Function to decorate.
Returns
-------
wrapper : TYPE
DESCRIPTION.
'''
@wraps(func)
def wrapper(*args,**kwargs):
'''
Time the wrapped function and print it.
Parameters
----------
*args : tuple
Arguments.
**kwargs : tuple
Key arguments.
Returns
-------
output : output
Output of the wrapped function.
'''
t1 = time.time()
output = func(*args,**kwargs)
t2 = time.time()-t1
logger.info(f'{func.__name__} ran in {t2} seconds')
return output
return wrapper
#%% save_and_load
[docs]
def save_and_load(func):
'''
Decorator function for saving and loading the function
Parameters
----------
func : function
Function to decorate.
Returns
-------
wrapper : function
Decorated function.
'''
signature = inspect.signature(func)
default_kwargs = {
kw: value.default for kw, value in signature.parameters.items() if value.default != inspect.Signature.empty
}
@wraps(func)
def wrapper(*args, **kwargs,):
'''
Loads and saves the output of the function into pickle file. The functionality is as follows
force_skip = True:
Skip the function entirely
force_load = True and force_skip = False:
Try to load the function instead of running
If failed, run the function instead
force_load = False and force_skip = False:
Run the function normally
Parameters
----------
*args : tuple
Arguments.
**kwargs : tuple
Key arguments.
Returns
-------
output : output
Output of the wrapped function.
'''
if not(os.path.exists(os.getcwd() + '/saved_data/')):
os.mkdir(os.getcwd() + '/saved_data/')
try:
file = os.getcwd() + '/saved_data/' + kwargs['pkl_name']
except:
file = os.getcwd() + '/saved_data/' + func.__name__ + '.pkl'
kwargs = default_kwargs | kwargs
try: # Exception so force_skip does not need to be defined
kwargs['force_skip']
except:
kwargs['force_skip'] = False
try: # Exception so force_load does not need to be defined
kwargs['force_load']
except:
kwargs['force_load'] = False
if kwargs['force_skip'] == True:
logger.info('Currently skipping working on: '+f'{func.__name__}')
return None
elif kwargs['force_load'] == True:
try:
with open(file, 'rb') as input_file:
output = pickle.load(input_file)
return output
except:
logger.info('Error opening input file: %s'%file)
pass
output = func(*args,**kwargs)
with open(file, 'wb') as output_file:
logger.info('Saved progress in %s'%file)
pickle.dump(output, output_file)
return output
return wrapper
#%% progress_tracker
[docs]
def progress_tracker(func):
'''
Decorator function for progress tracking of the function
Parameters
----------
func : function
Function to decorate.
Returns
-------
wrapper : function
Decorated function.
'''
@wraps(func)
def wrapper(*args,**kwargs):
'''
Prints the name of the function currently running
Parameters
----------
*args : tuple
Arguments.
**kwargs : tuple
Key arguments.
Returns
-------
output : output
Output of the wrapped function.
'''
logger.info('Currently working on: '+f'{func.__name__}')
output = func(*args,**kwargs)
return output
return wrapper
#%% disable_func
[docs]
def disable_func(func):
'''
This decorator disables a function. Use this to check breakability of code by disabling (for removing legacy code pieces).
Parameters
----------
func : function
Function to decorate.
Returns
-------
wrapper : function
Decorated function.
'''
@wraps(func)
def wrapper(*args,**kwargs):
logger.error('Function: '+f'{func.__name__}'+' has been called while disabled!')
pass
return wrapper
#%% skip_function
[docs]
def skip_function(func):
'''
This decorator disables a function based on a provided keyword. Use this to of code by disabling for time saving.
Parameters
----------
func : function
Function to decorate.
Returns
-------
wrapper : function
Decorated function.
'''
signature = inspect.signature(func)
default_kwargs = {
kw: value.default for kw, value in signature.parameters.items() if value.default != inspect.Signature.empty
}
@wraps(func)
def wrapper(*args,**kwargs):
kwargs = default_kwargs | kwargs
if kwargs['force_skip']:
logger.info('Currently skipping working on: '+f'{func.__name__}')
pass
else:
output = func(*args,**kwargs)
return output
return wrapper
#%% todo_function
[docs]
def todo_function(func):
'''
This decorator warns the use to TODO in function.
Parameters
----------
func : function
Function to decorate.
Returns
-------
wrapper : function
Decorated function.
'''
@wraps(func)
def wrapper(*args,**kwargs):
logger.warning('Function: '+f'{func.__name__}'+ ' requires attention. Check the documentation for more information (above the def command)')
output = func(*args,**kwargs)
return output
return wrapper
#%% save_figure
#%%
[docs]
def plot_different_setting(style):
def decorator(func):
def wrapper(*args, **kwargs):
with sns.axes_style(style):
output = func(*args,**kwargs)
return output
return wrapper
return decorator
#%%
if __name__ == '__main__':
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
style_dark = "dark"
style_white = "white"
@plot_different_setting(style_white)
def sinplot(n=10, flip=1):
x = np.linspace(0, 14, 100)
for i in range(1, n + 1):
plt.plot(x, np.sin(x + i * .5) * (n + 2 - i) * flip)
with sns.plotting_context("poster"):
sinplot()
#%% save_progress
[docs]
@disable_func
def save_progress(func,*args,file=None,force_calculate=False,no_save = False):
'''
Function to save progress of other functions
Input:
func ; function - function for which we want to save progress
file ; string - pickle file (*.pkl) where to save/load result
force_calculate = False ; True/False - Force to recalculate function
no_save = False ; True/False - Force not to save data (in case of big files
Output:
result ; *results - resulting arguments that we would normally receive from func(*args)
'''
if not(force_calculate): # Recalculate data?
try:
with open(file, 'rb') as input_file:
print('Result loaded in: %s'%(file))
result = pickle.load(input_file)
except: # Need to recalculate as it wasn't found
print('Recalculating data as saved pickle was not found in: %s'%file)
force_calculate=True
if force_calculate: # Recalculate
result = func(*args)
if not(no_save): # Save result?
with open(file, 'wb') as output_file:
print('Saved progress in %s'%file)
pickle.dump(result, output_file)
return result