diff --git a/modules/processing.py b/modules/processing.py
index e777a9651..65ae4846f 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -19,20 +19,11 @@ import modules.face_restoration
 import modules.images as images
 import modules.styles
 
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from transformers import AutoFeatureExtractor
-
-# load safety model
-safety_model_id = "CompVis/stable-diffusion-safety-checker"
-safety_feature_extractor = None
-safety_checker = None
-
 # some of those options should not be changed at all because they would break the model, so I removed them from options.
 opt_C = 4
 opt_f = 8
 
 
-
 class StableDiffusionProcessing:
     def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", prompt_style="None", seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
         self.sd_model = sd_model
@@ -154,28 +145,6 @@ def fix_seed(p):
     p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == -1 else p.subseed
 
 
-def numpy_to_pil(images):
-    """
-    Convert a numpy image or a batch of images to a PIL image.
-    """
-    if images.ndim == 3:
-        images = images[None, ...]
-    images = (images * 255).round().astype("uint8")
-    pil_images = [Image.fromarray(image) for image in images]
-
-    return pil_images
-
-# check and replace nsfw content
-def check_safety(x_image):
-    global safety_feature_extractor, safety_checker
-    if safety_feature_extractor is None:
-        safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
-        safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
-    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
-    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
-    return x_checked_image, has_nsfw_concept
-
-
 def process_images(p: StableDiffusionProcessing) -> Processed:
     """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
 
@@ -279,9 +248,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
             if opts.filter_nsfw:
-                x_samples_ddim_numpy = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
-                x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
-                x_samples_ddim = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
+                import modules.safety as safety
+                x_samples_ddim = safety.censor_batch(x_samples_ddim)
 
             for i, x_sample in enumerate(x_samples_ddim):
                 x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
diff --git a/modules/safety.py b/modules/safety.py
new file mode 100644
index 000000000..cff4b2783
--- /dev/null
+++ b/modules/safety.py
@@ -0,0 +1,63 @@
+import torch
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from transformers import AutoFeatureExtractor
+from PIL import Image
+
+import modules.shared as shared
+
+safety_model_id = "CompVis/stable-diffusion-safety-checker"
+safety_feature_extractor = None
+safety_checker = None
+
+
+def numpy_to_pil(images):
+    """
+    Convert a numpy image or a batch of images to a list of PIL images.
+
+    A 3-dimensional array is treated as a single image and gets a leading
+    batch axis. Values are multiplied by 255, rounded and cast to uint8, so
+    the input is presumably float data in [0, 1] -- TODO confirm with callers.
+    """
+    if images.ndim == 3:
+        images = images[None, ...]
+    images = (images * 255).round().astype("uint8")
+
+    return [Image.fromarray(image) for image in images]
+
+
+def check_safety(x_image):
+    """
+    Check a numpy batch of images for nsfw content and replace it.
+
+    The feature extractor and the checker are loaded from safety_model_id on
+    first use and cached in module globals. Returns the (checked images,
+    has_nsfw_concept) pair exactly as the checker produced it.
+    """
+    global safety_feature_extractor, safety_checker
+
+    # Test each global on its own: if loading the checker raised after the
+    # extractor had been cached, a single combined guard would skip loading
+    # on the next call and then try to call None.
+    if safety_feature_extractor is None:
+        safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
+    if safety_checker is None:
+        safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
+
+    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
+    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
+
+    return x_checked_image, has_nsfw_concept
+
+
+def censor_batch(x):
+    """
+    Pass a batch tensor through check_safety and return the checked batch.
+
+    The tensor is moved to the CPU and permuted with (0, 2, 3, 1) before the
+    check and with (0, 3, 1, 2) afterwards; the result is a CPU tensor built
+    with torch.from_numpy.
+    """
+    x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
+    x_checked_image, _ = check_safety(x_samples_ddim_numpy)
+
+    return torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)