# stable-diffusion-webui/modules/patches.py
#
# 65 lines
# 1.8 KiB
# Python

from collections import defaultdict
def patch(key, obj, field, replacement):
    """Replace an attribute (typically a function) on a module or a class.

    The value being replaced is remembered in the module-level ``originals``
    registry, so it can later be fetched with original(key, obj, field) or
    put back with undo(key, obj, field). Records are kept separately per
    ``key``, so a given (obj, field) can be patched only once per key until
    it is undone.

    Arguments:
        key: identifies who is doing the replacement; ``__name__`` works well.
        obj: the module or the class that owns the attribute.
        field: name of the attribute, as a string.
        replacement: the new value to install.

    Returns:
        The value that was in place before this call.

    Raises:
        RuntimeError: if ``key`` has already patched this (obj, field).
        AttributeError: if ``obj`` has no attribute named ``field``.
    """
    registry = originals[key]
    slot = (obj, field)
    if slot in registry:
        raise RuntimeError(f"patch for {field} is already applied")

    # Whatever getattr() yields right now is recorded as the "original" --
    # if a different key already patched this attribute, that is its replacement.
    # NOTE(review): getattr() applies the descriptor protocol, so a
    # staticmethod/classmethod on a class is saved unwrapped; confirm no caller
    # patches those before relying on undo() to restore them exactly.
    previous = getattr(obj, field)
    registry[slot] = previous
    setattr(obj, field, replacement)
    return previous
def undo(key, obj, field):
    """Undo the replacement made by patch().

    Puts back the value that patch(key, obj, field, ...) recorded and removes
    that record, so the same key may patch (obj, field) again afterwards.

    Arguments:
        key: identifying information for who did the replacement. You can use __name__.
        obj: the module or the class
        field: name of the function as a string

    Returns:
        Always None

    Raises:
        RuntimeError: if ``key`` has no recorded patch for (obj, field).
    """
    patch_key = (obj, field)
    # Use .get() instead of indexing: indexing the defaultdict would insert an
    # empty entry for an unknown key even though the call is about to fail.
    patches = originals.get(key)
    if patches is None or patch_key not in patches:
        raise RuntimeError(f"there is no patch for {field} to undo")

    original_func = patches.pop(patch_key)
    setattr(obj, field, original_func)
    return None
def original(key, obj, field):
    """Return the original value saved by patch() for (obj, field) under ``key``.

    Arguments:
        key: identifying information for who did the replacement.
        obj: the module or the class
        field: name of the function as a string

    Returns:
        The value that patch() replaced, or None if ``key`` currently has no
        patch recorded for (obj, field).
    """
    # Non-inserting lookup: indexing the defaultdict would add an empty entry
    # for every unknown key, growing the global registry on pure queries.
    return originals.get(key, {}).get((obj, field), None)
# Registry of saved values, grouped by the caller-supplied key:
#   originals[key][(obj, field)] -> attribute value that patch() replaced.
# A defaultdict(dict) lets patch() store under a brand-new key with no setup step.
originals: defaultdict = defaultdict(dict)