mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
Merge branch 'master' into hypernetwork-training
This commit is contained in:
commit
5de806184f
28
.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
vendored
Normal file
28
.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
vendored
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
# Please read the [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) before submitting a pull request!
|
||||||
|
|
||||||
|
If you have a large change, pay special attention to this paragraph:
|
||||||
|
|
||||||
|
> Before making changes, if you think that your feature will result in more than 100 lines changing, find me and talk to me about the feature you are proposing. It pains me to reject the hard work someone else did, but I won't add everything to the repo, and it's better if the rejection happens before you have to waste time working on the feature.
|
||||||
|
|
||||||
|
Otherwise, after making sure you're following the rules described in wiki page, remove this section and continue on.
|
||||||
|
|
||||||
|
**Describe what this pull request is trying to achieve.**
|
||||||
|
|
||||||
|
A clear and concise description of what you're trying to accomplish with this, so your intent doesn't have to be extracted from your code.
|
||||||
|
|
||||||
|
**Additional notes and description of your changes**
|
||||||
|
|
||||||
|
More technical discussion about your changes go here, plus anything that a maintainer might have to specifically take a look at, or be wary of.
|
||||||
|
|
||||||
|
**Environment this was tested in**
|
||||||
|
|
||||||
|
List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box.
|
||||||
|
- OS: [e.g. Windows, Linux]
|
||||||
|
- Browser [e.g. chrome, safari]
|
||||||
|
- Graphics card [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB]
|
||||||
|
|
||||||
|
**Screenshots or videos of your changes**
|
||||||
|
|
||||||
|
If applicable, screenshots or a video showing off your changes. If it edits an existing UI, it should ideally contain a comparison of what used to be there, before your changes were made.
|
||||||
|
|
||||||
|
This is **required** for anything that touches the user interface.
|
@ -16,7 +16,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
|
|||||||
- Attention, specify parts of text that the model should pay more attention to
|
- Attention, specify parts of text that the model should pay more attention to
|
||||||
- a man in a ((tuxedo)) - will pay more attention to tuxedo
|
- a man in a ((tuxedo)) - will pay more attention to tuxedo
|
||||||
- a man in a (tuxedo:1.21) - alternative syntax
|
- a man in a (tuxedo:1.21) - alternative syntax
|
||||||
- select text and press ctrl+up or ctrl+down to aduotmatically adjust attention to selected text
|
- select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user)
|
||||||
- Loopback, run img2img processing multiple times
|
- Loopback, run img2img processing multiple times
|
||||||
- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
|
- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
|
||||||
- Textual Inversion
|
- Textual Inversion
|
||||||
@ -65,6 +65,8 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
|
|||||||
- [Composable-Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/), a way to use multiple prompts at once
|
- [Composable-Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/), a way to use multiple prompts at once
|
||||||
- separate prompts using uppercase `AND`
|
- separate prompts using uppercase `AND`
|
||||||
- also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2`
|
- also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2`
|
||||||
|
- No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
|
||||||
|
- DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args)
|
||||||
|
|
||||||
## Installation and Running
|
## Installation and Running
|
||||||
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
|
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
|
||||||
@ -122,4 +124,5 @@ The documentation was moved from this README over to the project's [wiki](https:
|
|||||||
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
|
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
|
||||||
- CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
|
- CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
|
||||||
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
||||||
|
- DeepDanbooru - interrogator for anime diffusors https://github.com/KichangKim/DeepDanbooru
|
||||||
- (You)
|
- (You)
|
||||||
|
168
javascript/contextMenus.js
Normal file
168
javascript/contextMenus.js
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
|
||||||
|
contextMenuInit = function(){
|
||||||
|
let eventListenerApplied=false;
|
||||||
|
let menuSpecs = new Map();
|
||||||
|
|
||||||
|
const uid = function(){
|
||||||
|
return Date.now().toString(36) + Math.random().toString(36).substr(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
function showContextMenu(event,element,menuEntries){
|
||||||
|
let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
|
||||||
|
let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;
|
||||||
|
|
||||||
|
let oldMenu = gradioApp().querySelector('#context-menu')
|
||||||
|
if(oldMenu){
|
||||||
|
oldMenu.remove()
|
||||||
|
}
|
||||||
|
|
||||||
|
let tabButton = gradioApp().querySelector('button')
|
||||||
|
let baseStyle = window.getComputedStyle(tabButton)
|
||||||
|
|
||||||
|
const contextMenu = document.createElement('nav')
|
||||||
|
contextMenu.id = "context-menu"
|
||||||
|
contextMenu.style.background = baseStyle.background
|
||||||
|
contextMenu.style.color = baseStyle.color
|
||||||
|
contextMenu.style.fontFamily = baseStyle.fontFamily
|
||||||
|
contextMenu.style.top = posy+'px'
|
||||||
|
contextMenu.style.left = posx+'px'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
const contextMenuList = document.createElement('ul')
|
||||||
|
contextMenuList.className = 'context-menu-items';
|
||||||
|
contextMenu.append(contextMenuList);
|
||||||
|
|
||||||
|
menuEntries.forEach(function(entry){
|
||||||
|
let contextMenuEntry = document.createElement('a')
|
||||||
|
contextMenuEntry.innerHTML = entry['name']
|
||||||
|
contextMenuEntry.addEventListener("click", function(e) {
|
||||||
|
entry['func']();
|
||||||
|
})
|
||||||
|
contextMenuList.append(contextMenuEntry);
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
gradioApp().getRootNode().appendChild(contextMenu)
|
||||||
|
|
||||||
|
let menuWidth = contextMenu.offsetWidth + 4;
|
||||||
|
let menuHeight = contextMenu.offsetHeight + 4;
|
||||||
|
|
||||||
|
let windowWidth = window.innerWidth;
|
||||||
|
let windowHeight = window.innerHeight;
|
||||||
|
|
||||||
|
if ( (windowWidth - posx) < menuWidth ) {
|
||||||
|
contextMenu.style.left = windowWidth - menuWidth + "px";
|
||||||
|
}
|
||||||
|
|
||||||
|
if ( (windowHeight - posy) < menuHeight ) {
|
||||||
|
contextMenu.style.top = windowHeight - menuHeight + "px";
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
function appendContextMenuOption(targetEmementSelector,entryName,entryFunction){
|
||||||
|
|
||||||
|
currentItems = menuSpecs.get(targetEmementSelector)
|
||||||
|
|
||||||
|
if(!currentItems){
|
||||||
|
currentItems = []
|
||||||
|
menuSpecs.set(targetEmementSelector,currentItems);
|
||||||
|
}
|
||||||
|
let newItem = {'id':targetEmementSelector+'_'+uid(),
|
||||||
|
'name':entryName,
|
||||||
|
'func':entryFunction,
|
||||||
|
'isNew':true}
|
||||||
|
|
||||||
|
currentItems.push(newItem)
|
||||||
|
return newItem['id']
|
||||||
|
}
|
||||||
|
|
||||||
|
function removeContextMenuOption(uid){
|
||||||
|
menuSpecs.forEach(function(v,k) {
|
||||||
|
let index = -1
|
||||||
|
v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
|
||||||
|
if(index>=0){
|
||||||
|
v.splice(index, 1);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
function addContextMenuEventListener(){
|
||||||
|
if(eventListenerApplied){
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
gradioApp().addEventListener("click", function(e) {
|
||||||
|
let source = e.composedPath()[0]
|
||||||
|
if(source.id && source.indexOf('check_progress')>-1){
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
let oldMenu = gradioApp().querySelector('#context-menu')
|
||||||
|
if(oldMenu){
|
||||||
|
oldMenu.remove()
|
||||||
|
}
|
||||||
|
});
|
||||||
|
gradioApp().addEventListener("contextmenu", function(e) {
|
||||||
|
let oldMenu = gradioApp().querySelector('#context-menu')
|
||||||
|
if(oldMenu){
|
||||||
|
oldMenu.remove()
|
||||||
|
}
|
||||||
|
menuSpecs.forEach(function(v,k) {
|
||||||
|
if(e.composedPath()[0].matches(k)){
|
||||||
|
showContextMenu(e,e.composedPath()[0],v)
|
||||||
|
e.preventDefault()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
});
|
||||||
|
eventListenerApplied=true
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]
|
||||||
|
}
|
||||||
|
|
||||||
|
initResponse = contextMenuInit()
|
||||||
|
appendContextMenuOption = initResponse[0]
|
||||||
|
removeContextMenuOption = initResponse[1]
|
||||||
|
addContextMenuEventListener = initResponse[2]
|
||||||
|
|
||||||
|
|
||||||
|
//Start example Context Menu Items
|
||||||
|
generateOnRepeatId = appendContextMenuOption('#txt2img_generate','Generate forever',function(){
|
||||||
|
let genbutton = gradioApp().querySelector('#txt2img_generate');
|
||||||
|
let interruptbutton = gradioApp().querySelector('#txt2img_interrupt');
|
||||||
|
if(!interruptbutton.offsetParent){
|
||||||
|
genbutton.click();
|
||||||
|
}
|
||||||
|
clearInterval(window.generateOnRepeatInterval)
|
||||||
|
window.generateOnRepeatInterval = setInterval(function(){
|
||||||
|
if(!interruptbutton.offsetParent){
|
||||||
|
genbutton.click();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
500)}
|
||||||
|
)
|
||||||
|
|
||||||
|
cancelGenerateForever = function(){
|
||||||
|
clearInterval(window.generateOnRepeatInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever)
|
||||||
|
appendContextMenuOption('#txt2img_generate', 'Cancel generate forever',cancelGenerateForever)
|
||||||
|
|
||||||
|
|
||||||
|
appendContextMenuOption('#roll','Roll three',
|
||||||
|
function(){
|
||||||
|
let rollbutton = gradioApp().querySelector('#roll');
|
||||||
|
setTimeout(function(){rollbutton.click()},100)
|
||||||
|
setTimeout(function(){rollbutton.click()},200)
|
||||||
|
setTimeout(function(){rollbutton.click()},300)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
//End example Context Menu Items
|
||||||
|
|
||||||
|
onUiUpdate(function(){
|
||||||
|
addContextMenuEventListener()
|
||||||
|
});
|
@ -1,5 +1,5 @@
|
|||||||
addEventListener('keydown', (event) => {
|
addEventListener('keydown', (event) => {
|
||||||
let target = event.originalTarget;
|
let target = event.originalTarget || event.composedPath()[0];
|
||||||
if (!target.hasAttribute("placeholder")) return;
|
if (!target.hasAttribute("placeholder")) return;
|
||||||
if (!target.placeholder.toLowerCase().includes("prompt")) return;
|
if (!target.placeholder.toLowerCase().includes("prompt")) return;
|
||||||
|
|
||||||
|
@ -35,6 +35,7 @@ titles = {
|
|||||||
"Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.",
|
"Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.",
|
||||||
"Denoising strength change factor": "In loopback mode, on each loop the denoising strength is multiplied by this value. <1 means decreasing variety so your sequence will converge on a fixed picture. >1 means increasing variety so your sequence will become more and more chaotic.",
|
"Denoising strength change factor": "In loopback mode, on each loop the denoising strength is multiplied by this value. <1 means decreasing variety so your sequence will converge on a fixed picture. >1 means increasing variety so your sequence will become more and more chaotic.",
|
||||||
|
|
||||||
|
"Skip": "Stop processing current image and continue processing.",
|
||||||
"Interrupt": "Stop processing images and return any results accumulated so far.",
|
"Interrupt": "Stop processing images and return any results accumulated so far.",
|
||||||
"Save": "Write image to a directory (default - log/images) and generation parameters into csv file.",
|
"Save": "Write image to a directory (default - log/images) and generation parameters into csv file.",
|
||||||
|
|
||||||
@ -78,6 +79,8 @@ titles = {
|
|||||||
"Highres. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition",
|
"Highres. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition",
|
||||||
"Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.",
|
"Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.",
|
||||||
|
|
||||||
|
"Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
|
||||||
|
"Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be bevaing in an unethical manner.",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,72 +1,97 @@
|
|||||||
// A full size 'lightbox' preview modal shown when left clicking on gallery previews
|
// A full size 'lightbox' preview modal shown when left clicking on gallery previews
|
||||||
|
|
||||||
function closeModal() {
|
function closeModal() {
|
||||||
gradioApp().getElementById("lightboxModal").style.display = "none";
|
gradioApp().getElementById("lightboxModal").style.display = "none";
|
||||||
}
|
}
|
||||||
|
|
||||||
function showModal(event) {
|
function showModal(event) {
|
||||||
const source = event.target || event.srcElement;
|
const source = event.target || event.srcElement;
|
||||||
const modalImage = gradioApp().getElementById("modalImage")
|
const modalImage = gradioApp().getElementById("modalImage")
|
||||||
const lb = gradioApp().getElementById("lightboxModal")
|
const lb = gradioApp().getElementById("lightboxModal")
|
||||||
modalImage.src = source.src
|
modalImage.src = source.src
|
||||||
if (modalImage.style.display === 'none') {
|
if (modalImage.style.display === 'none') {
|
||||||
lb.style.setProperty('background-image', 'url(' + source.src + ')');
|
lb.style.setProperty('background-image', 'url(' + source.src + ')');
|
||||||
}
|
}
|
||||||
lb.style.display = "block";
|
lb.style.display = "block";
|
||||||
lb.focus()
|
lb.focus()
|
||||||
event.stopPropagation()
|
event.stopPropagation()
|
||||||
}
|
}
|
||||||
|
|
||||||
function negmod(n, m) {
|
function negmod(n, m) {
|
||||||
return ((n % m) + m) % m;
|
return ((n % m) + m) % m;
|
||||||
}
|
}
|
||||||
|
|
||||||
function modalImageSwitch(offset){
|
function updateOnBackgroundChange() {
|
||||||
var allgalleryButtons = gradioApp().querySelectorAll(".gallery-item.transition-all")
|
const modalImage = gradioApp().getElementById("modalImage")
|
||||||
var galleryButtons = []
|
if (modalImage && modalImage.offsetParent) {
|
||||||
allgalleryButtons.forEach(function(elem){
|
let allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
|
||||||
if(elem.parentElement.offsetParent){
|
let currentButton = null
|
||||||
galleryButtons.push(elem);
|
allcurrentButtons.forEach(function(elem) {
|
||||||
|
if (elem.parentElement.offsetParent) {
|
||||||
|
currentButton = elem;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if (modalImage.src != currentButton.children[0].src) {
|
||||||
|
modalImage.src = currentButton.children[0].src;
|
||||||
|
if (modalImage.style.display === 'none') {
|
||||||
|
modal.style.setProperty('background-image', `url(${modalImage.src})`)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
}
|
||||||
|
|
||||||
if(galleryButtons.length>1){
|
function modalImageSwitch(offset) {
|
||||||
var allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
|
var allgalleryButtons = gradioApp().querySelectorAll(".gallery-item.transition-all")
|
||||||
var currentButton = null
|
var galleryButtons = []
|
||||||
allcurrentButtons.forEach(function(elem){
|
allgalleryButtons.forEach(function(elem) {
|
||||||
if(elem.parentElement.offsetParent){
|
if (elem.parentElement.offsetParent) {
|
||||||
currentButton = elem;
|
galleryButtons.push(elem);
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
var result = -1
|
if (galleryButtons.length > 1) {
|
||||||
galleryButtons.forEach(function(v, i){ if(v==currentButton) { result = i } })
|
var allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
|
||||||
|
var currentButton = null
|
||||||
|
allcurrentButtons.forEach(function(elem) {
|
||||||
|
if (elem.parentElement.offsetParent) {
|
||||||
|
currentButton = elem;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
if(result != -1){
|
var result = -1
|
||||||
nextButton = galleryButtons[negmod((result+offset),galleryButtons.length)]
|
galleryButtons.forEach(function(v, i) {
|
||||||
nextButton.click()
|
if (v == currentButton) {
|
||||||
const modalImage = gradioApp().getElementById("modalImage");
|
result = i
|
||||||
const modal = gradioApp().getElementById("lightboxModal");
|
}
|
||||||
modalImage.src = nextButton.children[0].src;
|
})
|
||||||
if (modalImage.style.display === 'none') {
|
|
||||||
modal.style.setProperty('background-image', `url(${modalImage.src})`)
|
if (result != -1) {
|
||||||
|
nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
|
||||||
|
nextButton.click()
|
||||||
|
const modalImage = gradioApp().getElementById("modalImage");
|
||||||
|
const modal = gradioApp().getElementById("lightboxModal");
|
||||||
|
modalImage.src = nextButton.children[0].src;
|
||||||
|
if (modalImage.style.display === 'none') {
|
||||||
|
modal.style.setProperty('background-image', `url(${modalImage.src})`)
|
||||||
|
}
|
||||||
|
setTimeout(function() {
|
||||||
|
modal.focus()
|
||||||
|
}, 10)
|
||||||
}
|
}
|
||||||
setTimeout( function(){modal.focus()},10)
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function modalNextImage(event){
|
function modalNextImage(event) {
|
||||||
modalImageSwitch(1)
|
modalImageSwitch(1)
|
||||||
event.stopPropagation()
|
event.stopPropagation()
|
||||||
}
|
}
|
||||||
|
|
||||||
function modalPrevImage(event){
|
function modalPrevImage(event) {
|
||||||
modalImageSwitch(-1)
|
modalImageSwitch(-1)
|
||||||
event.stopPropagation()
|
event.stopPropagation()
|
||||||
}
|
}
|
||||||
|
|
||||||
function modalKeyHandler(event){
|
function modalKeyHandler(event) {
|
||||||
switch (event.key) {
|
switch (event.key) {
|
||||||
case "ArrowLeft":
|
case "ArrowLeft":
|
||||||
modalPrevImage(event)
|
modalPrevImage(event)
|
||||||
@ -80,21 +105,22 @@ function modalKeyHandler(event){
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function showGalleryImage(){
|
function showGalleryImage() {
|
||||||
setTimeout(function() {
|
setTimeout(function() {
|
||||||
fullImg_preview = gradioApp().querySelectorAll('img.w-full.object-contain')
|
fullImg_preview = gradioApp().querySelectorAll('img.w-full.object-contain')
|
||||||
|
|
||||||
if(fullImg_preview != null){
|
if (fullImg_preview != null) {
|
||||||
fullImg_preview.forEach(function function_name(e) {
|
fullImg_preview.forEach(function function_name(e) {
|
||||||
|
if (e.dataset.modded)
|
||||||
|
return;
|
||||||
|
e.dataset.modded = true;
|
||||||
if(e && e.parentElement.tagName == 'DIV'){
|
if(e && e.parentElement.tagName == 'DIV'){
|
||||||
|
|
||||||
e.style.cursor='pointer'
|
e.style.cursor='pointer'
|
||||||
|
|
||||||
e.addEventListener('click', function (evt) {
|
e.addEventListener('click', function (evt) {
|
||||||
if(!opts.js_modal_lightbox) return;
|
if(!opts.js_modal_lightbox) return;
|
||||||
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initialy_zoomed)
|
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
|
||||||
showModal(evt)
|
showModal(evt)
|
||||||
},true);
|
}, true);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -102,21 +128,21 @@ function showGalleryImage(){
|
|||||||
}, 100);
|
}, 100);
|
||||||
}
|
}
|
||||||
|
|
||||||
function modalZoomSet(modalImage, enable){
|
function modalZoomSet(modalImage, enable) {
|
||||||
if( enable ){
|
if (enable) {
|
||||||
modalImage.classList.add('modalImageFullscreen');
|
modalImage.classList.add('modalImageFullscreen');
|
||||||
} else{
|
} else {
|
||||||
modalImage.classList.remove('modalImageFullscreen');
|
modalImage.classList.remove('modalImageFullscreen');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function modalZoomToggle(event){
|
function modalZoomToggle(event) {
|
||||||
modalImage = gradioApp().getElementById("modalImage");
|
modalImage = gradioApp().getElementById("modalImage");
|
||||||
modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'))
|
modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'))
|
||||||
event.stopPropagation()
|
event.stopPropagation()
|
||||||
}
|
}
|
||||||
|
|
||||||
function modalTileImageToggle(event){
|
function modalTileImageToggle(event) {
|
||||||
const modalImage = gradioApp().getElementById("modalImage");
|
const modalImage = gradioApp().getElementById("modalImage");
|
||||||
const modal = gradioApp().getElementById("lightboxModal");
|
const modal = gradioApp().getElementById("lightboxModal");
|
||||||
const isTiling = modalImage.style.display === 'none';
|
const isTiling = modalImage.style.display === 'none';
|
||||||
@ -131,17 +157,18 @@ function modalTileImageToggle(event){
|
|||||||
event.stopPropagation()
|
event.stopPropagation()
|
||||||
}
|
}
|
||||||
|
|
||||||
function galleryImageHandler(e){
|
function galleryImageHandler(e) {
|
||||||
if(e && e.parentElement.tagName == 'BUTTON'){
|
if (e && e.parentElement.tagName == 'BUTTON') {
|
||||||
e.onclick = showGalleryImage;
|
e.onclick = showGalleryImage;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
onUiUpdate(function(){
|
onUiUpdate(function() {
|
||||||
fullImg_preview = gradioApp().querySelectorAll('img.w-full')
|
fullImg_preview = gradioApp().querySelectorAll('img.w-full')
|
||||||
if(fullImg_preview != null){
|
if (fullImg_preview != null) {
|
||||||
fullImg_preview.forEach(galleryImageHandler);
|
fullImg_preview.forEach(galleryImageHandler);
|
||||||
}
|
}
|
||||||
|
updateOnBackgroundChange();
|
||||||
})
|
})
|
||||||
|
|
||||||
document.addEventListener("DOMContentLoaded", function() {
|
document.addEventListener("DOMContentLoaded", function() {
|
||||||
@ -149,7 +176,7 @@ document.addEventListener("DOMContentLoaded", function() {
|
|||||||
const modal = document.createElement('div')
|
const modal = document.createElement('div')
|
||||||
modal.onclick = closeModal;
|
modal.onclick = closeModal;
|
||||||
modal.id = "lightboxModal";
|
modal.id = "lightboxModal";
|
||||||
modal.tabIndex=0
|
modal.tabIndex = 0
|
||||||
modal.addEventListener('keydown', modalKeyHandler, true)
|
modal.addEventListener('keydown', modalKeyHandler, true)
|
||||||
|
|
||||||
const modalControls = document.createElement('div')
|
const modalControls = document.createElement('div')
|
||||||
@ -180,23 +207,23 @@ document.addEventListener("DOMContentLoaded", function() {
|
|||||||
const modalImage = document.createElement('img')
|
const modalImage = document.createElement('img')
|
||||||
modalImage.id = 'modalImage';
|
modalImage.id = 'modalImage';
|
||||||
modalImage.onclick = closeModal;
|
modalImage.onclick = closeModal;
|
||||||
modalImage.tabIndex=0
|
modalImage.tabIndex = 0
|
||||||
modalImage.addEventListener('keydown', modalKeyHandler, true)
|
modalImage.addEventListener('keydown', modalKeyHandler, true)
|
||||||
modal.appendChild(modalImage)
|
modal.appendChild(modalImage)
|
||||||
|
|
||||||
const modalPrev = document.createElement('a')
|
const modalPrev = document.createElement('a')
|
||||||
modalPrev.className = 'modalPrev';
|
modalPrev.className = 'modalPrev';
|
||||||
modalPrev.innerHTML = '❮'
|
modalPrev.innerHTML = '❮'
|
||||||
modalPrev.tabIndex=0
|
modalPrev.tabIndex = 0
|
||||||
modalPrev.addEventListener('click',modalPrevImage,true);
|
modalPrev.addEventListener('click', modalPrevImage, true);
|
||||||
modalPrev.addEventListener('keydown', modalKeyHandler, true)
|
modalPrev.addEventListener('keydown', modalKeyHandler, true)
|
||||||
modal.appendChild(modalPrev)
|
modal.appendChild(modalPrev)
|
||||||
|
|
||||||
const modalNext = document.createElement('a')
|
const modalNext = document.createElement('a')
|
||||||
modalNext.className = 'modalNext';
|
modalNext.className = 'modalNext';
|
||||||
modalNext.innerHTML = '❯'
|
modalNext.innerHTML = '❯'
|
||||||
modalNext.tabIndex=0
|
modalNext.tabIndex = 0
|
||||||
modalNext.addEventListener('click',modalNextImage,true);
|
modalNext.addEventListener('click', modalNextImage, true);
|
||||||
modalNext.addEventListener('keydown', modalKeyHandler, true)
|
modalNext.addEventListener('keydown', modalKeyHandler, true)
|
||||||
|
|
||||||
modal.appendChild(modalNext)
|
modal.appendChild(modalNext)
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
// code related to showing and updating progressbar shown as the image is being made
|
// code related to showing and updating progressbar shown as the image is being made
|
||||||
global_progressbars = {}
|
global_progressbars = {}
|
||||||
|
|
||||||
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_interrupt, id_preview, id_gallery){
|
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
|
||||||
var progressbar = gradioApp().getElementById(id_progressbar)
|
var progressbar = gradioApp().getElementById(id_progressbar)
|
||||||
|
var skip = id_skip ? gradioApp().getElementById(id_skip) : null
|
||||||
var interrupt = gradioApp().getElementById(id_interrupt)
|
var interrupt = gradioApp().getElementById(id_interrupt)
|
||||||
|
|
||||||
if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
|
if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
|
||||||
@ -32,30 +33,37 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte
|
|||||||
|
|
||||||
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
|
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
|
||||||
if(!progressDiv){
|
if(!progressDiv){
|
||||||
|
if (skip) {
|
||||||
|
skip.style.display = "none"
|
||||||
|
}
|
||||||
interrupt.style.display = "none"
|
interrupt.style.display = "none"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
window.setTimeout(function(){ requestMoreProgress(id_part, id_progressbar_span, id_interrupt) }, 500)
|
window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500)
|
||||||
});
|
});
|
||||||
mutationObserver.observe( progressbar, { childList:true, subtree:true })
|
mutationObserver.observe( progressbar, { childList:true, subtree:true })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
onUiUpdate(function(){
|
onUiUpdate(function(){
|
||||||
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
|
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
|
||||||
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
|
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
|
||||||
check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', 'ti_interrupt', 'ti_preview', 'ti_gallery')
|
check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery')
|
||||||
})
|
})
|
||||||
|
|
||||||
function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){
|
function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){
|
||||||
btn = gradioApp().getElementById(id_part+"_check_progress");
|
btn = gradioApp().getElementById(id_part+"_check_progress");
|
||||||
if(btn==null) return;
|
if(btn==null) return;
|
||||||
|
|
||||||
btn.click();
|
btn.click();
|
||||||
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
|
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
|
||||||
|
var skip = id_skip ? gradioApp().getElementById(id_skip) : null
|
||||||
var interrupt = gradioApp().getElementById(id_interrupt)
|
var interrupt = gradioApp().getElementById(id_interrupt)
|
||||||
if(progressDiv && interrupt){
|
if(progressDiv && interrupt){
|
||||||
|
if (skip) {
|
||||||
|
skip.style.display = "block"
|
||||||
|
}
|
||||||
interrupt.style.display = "block"
|
interrupt.style.display = "block"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
137
launch.py
137
launch.py
@ -4,39 +4,17 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import shlex
|
import shlex
|
||||||
|
import platform
|
||||||
|
|
||||||
dir_repos = "repositories"
|
dir_repos = "repositories"
|
||||||
dir_tmp = "tmp"
|
|
||||||
|
|
||||||
python = sys.executable
|
python = sys.executable
|
||||||
git = os.environ.get('GIT', "git")
|
git = os.environ.get('GIT', "git")
|
||||||
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
|
|
||||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
|
||||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
|
||||||
|
|
||||||
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
|
||||||
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
|
|
||||||
|
|
||||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
|
|
||||||
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
|
|
||||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "f4e99857772fc3a126ba886aadf795a332774878")
|
|
||||||
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
|
||||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
|
||||||
|
|
||||||
args = shlex.split(commandline_args)
|
|
||||||
|
|
||||||
|
|
||||||
def extract_arg(args, name):
|
def extract_arg(args, name):
|
||||||
return [x for x in args if x != name], name in args
|
return [x for x in args if x != name], name in args
|
||||||
|
|
||||||
|
|
||||||
args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test')
|
|
||||||
|
|
||||||
|
|
||||||
def repo_dir(name):
|
|
||||||
return os.path.join(dir_repos, name)
|
|
||||||
|
|
||||||
|
|
||||||
def run(command, desc=None, errdesc=None):
|
def run(command, desc=None, errdesc=None):
|
||||||
if desc is not None:
|
if desc is not None:
|
||||||
print(desc)
|
print(desc)
|
||||||
@ -56,23 +34,11 @@ stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.st
|
|||||||
return result.stdout.decode(encoding="utf8", errors="ignore")
|
return result.stdout.decode(encoding="utf8", errors="ignore")
|
||||||
|
|
||||||
|
|
||||||
def run_python(code, desc=None, errdesc=None):
|
|
||||||
return run(f'"{python}" -c "{code}"', desc, errdesc)
|
|
||||||
|
|
||||||
|
|
||||||
def run_pip(args, desc=None):
|
|
||||||
return run(f'"{python}" -m pip {args} --prefer-binary', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
|
|
||||||
|
|
||||||
|
|
||||||
def check_run(command):
|
def check_run(command):
|
||||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
|
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
|
||||||
return result.returncode == 0
|
return result.returncode == 0
|
||||||
|
|
||||||
|
|
||||||
def check_run_python(code):
|
|
||||||
return check_run(f'"{python}" -c "{code}"')
|
|
||||||
|
|
||||||
|
|
||||||
def is_installed(package):
|
def is_installed(package):
|
||||||
try:
|
try:
|
||||||
spec = importlib.util.find_spec(package)
|
spec = importlib.util.find_spec(package)
|
||||||
@ -82,6 +48,22 @@ def is_installed(package):
|
|||||||
return spec is not None
|
return spec is not None
|
||||||
|
|
||||||
|
|
||||||
|
def repo_dir(name):
|
||||||
|
return os.path.join(dir_repos, name)
|
||||||
|
|
||||||
|
|
||||||
|
def run_python(code, desc=None, errdesc=None):
|
||||||
|
return run(f'"{python}" -c "{code}"', desc, errdesc)
|
||||||
|
|
||||||
|
|
||||||
|
def run_pip(args, desc=None):
|
||||||
|
return run(f'"{python}" -m pip {args} --prefer-binary', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
|
||||||
|
|
||||||
|
|
||||||
|
def check_run_python(code):
|
||||||
|
return check_run(f'"{python}" -c "{code}"')
|
||||||
|
|
||||||
|
|
||||||
def git_clone(url, dir, name, commithash=None):
|
def git_clone(url, dir, name, commithash=None):
|
||||||
# TODO clone into temporary dir and move if successful
|
# TODO clone into temporary dir and move if successful
|
||||||
|
|
||||||
@ -103,50 +85,81 @@ def git_clone(url, dir, name, commithash=None):
|
|||||||
run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
||||||
|
|
||||||
|
|
||||||
try:
|
def prepare_enviroment():
|
||||||
commit = run(f"{git} rev-parse HEAD").strip()
|
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
|
||||||
except Exception:
|
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||||
commit = "<none>"
|
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
||||||
|
|
||||||
print(f"Python {sys.version}")
|
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
||||||
print(f"Commit hash: {commit}")
|
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
|
||||||
|
|
||||||
|
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
|
||||||
|
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
|
||||||
|
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "f4e99857772fc3a126ba886aadf795a332774878")
|
||||||
|
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
||||||
|
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||||
|
|
||||||
if not is_installed("torch") or not is_installed("torchvision"):
|
args = shlex.split(commandline_args)
|
||||||
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")
|
|
||||||
|
|
||||||
if not skip_torch_cuda_test:
|
args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test')
|
||||||
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
|
xformers = '--xformers' in args
|
||||||
|
deepdanbooru = '--deepdanbooru' in args
|
||||||
|
|
||||||
if not is_installed("gfpgan"):
|
try:
|
||||||
run_pip(f"install {gfpgan_package}", "gfpgan")
|
commit = run(f"{git} rev-parse HEAD").strip()
|
||||||
|
except Exception:
|
||||||
|
commit = "<none>"
|
||||||
|
|
||||||
if not is_installed("clip"):
|
print(f"Python {sys.version}")
|
||||||
run_pip(f"install {clip_package}", "clip")
|
print(f"Commit hash: {commit}")
|
||||||
|
|
||||||
os.makedirs(dir_repos, exist_ok=True)
|
if not is_installed("torch") or not is_installed("torchvision"):
|
||||||
|
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")
|
||||||
|
|
||||||
git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
|
if not skip_torch_cuda_test:
|
||||||
git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
|
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
|
||||||
git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
|
||||||
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
|
||||||
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
|
||||||
|
|
||||||
if not is_installed("lpips"):
|
if not is_installed("gfpgan"):
|
||||||
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
|
run_pip(f"install {gfpgan_package}", "gfpgan")
|
||||||
|
|
||||||
run_pip(f"install -r {requirements_file}", "requirements for Web UI")
|
if not is_installed("clip"):
|
||||||
|
run_pip(f"install {clip_package}", "clip")
|
||||||
|
|
||||||
sys.argv += args
|
if not is_installed("xformers") and xformers and platform.python_version().startswith("3.10"):
|
||||||
|
if platform.system() == "Windows":
|
||||||
|
run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/c/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers")
|
||||||
|
elif platform.system() == "Linux":
|
||||||
|
run_pip("install xformers", "xformers")
|
||||||
|
|
||||||
|
if not is_installed("deepdanbooru") and deepdanbooru:
|
||||||
|
run_pip("install git+https://github.com/KichangKim/DeepDanbooru.git@edf73df4cdaeea2cf00e9ac08bd8a9026b7a7b26#egg=deepdanbooru[tensorflow] tensorflow==2.10.0 tensorflow-io==0.27.0", "deepdanbooru")
|
||||||
|
|
||||||
|
os.makedirs(dir_repos, exist_ok=True)
|
||||||
|
|
||||||
|
git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||||
|
git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
|
||||||
|
git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
||||||
|
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
||||||
|
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||||
|
|
||||||
|
if not is_installed("lpips"):
|
||||||
|
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
|
||||||
|
|
||||||
|
run_pip(f"install -r {requirements_file}", "requirements for Web UI")
|
||||||
|
|
||||||
|
sys.argv += args
|
||||||
|
|
||||||
|
if "--exit" in args:
|
||||||
|
print("Exiting because of --exit argument")
|
||||||
|
exit(0)
|
||||||
|
|
||||||
if "--exit" in args:
|
|
||||||
print("Exiting because of --exit argument")
|
|
||||||
exit(0)
|
|
||||||
|
|
||||||
def start_webui():
|
def start_webui():
|
||||||
print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}")
|
print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}")
|
||||||
import webui
|
import webui
|
||||||
webui.webui()
|
webui.webui()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
prepare_enviroment()
|
||||||
start_webui()
|
start_webui()
|
||||||
|
@ -10,13 +10,11 @@ from basicsr.utils.download_util import load_file_from_url
|
|||||||
import modules.upscaler
|
import modules.upscaler
|
||||||
from modules import devices, modelloader
|
from modules import devices, modelloader
|
||||||
from modules.bsrgan_model_arch import RRDBNet
|
from modules.bsrgan_model_arch import RRDBNet
|
||||||
from modules.paths import models_path
|
|
||||||
|
|
||||||
|
|
||||||
class UpscalerBSRGAN(modules.upscaler.Upscaler):
|
class UpscalerBSRGAN(modules.upscaler.Upscaler):
|
||||||
def __init__(self, dirname):
|
def __init__(self, dirname):
|
||||||
self.name = "BSRGAN"
|
self.name = "BSRGAN"
|
||||||
self.model_path = os.path.join(models_path, self.name)
|
|
||||||
self.model_name = "BSRGAN 4x"
|
self.model_name = "BSRGAN 4x"
|
||||||
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
|
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
|
||||||
self.user_path = dirname
|
self.user_path = dirname
|
||||||
|
73
modules/deepbooru.py
Normal file
73
modules/deepbooru.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
import os.path
|
||||||
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
|
from multiprocessing import get_context
|
||||||
|
|
||||||
|
|
||||||
|
def _load_tf_and_return_tags(pil_image, threshold):
|
||||||
|
import deepdanbooru as dd
|
||||||
|
import tensorflow as tf
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
this_folder = os.path.dirname(__file__)
|
||||||
|
model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
|
||||||
|
if not os.path.exists(os.path.join(model_path, 'project.json')):
|
||||||
|
# there is no point importing these every time
|
||||||
|
import zipfile
|
||||||
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
load_file_from_url(r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
|
||||||
|
model_path)
|
||||||
|
with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
|
||||||
|
zip_ref.extractall(model_path)
|
||||||
|
os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))
|
||||||
|
|
||||||
|
tags = dd.project.load_tags_from_project(model_path)
|
||||||
|
model = dd.project.load_model_from_project(
|
||||||
|
model_path, compile_model=True
|
||||||
|
)
|
||||||
|
|
||||||
|
width = model.input_shape[2]
|
||||||
|
height = model.input_shape[1]
|
||||||
|
image = np.array(pil_image)
|
||||||
|
image = tf.image.resize(
|
||||||
|
image,
|
||||||
|
size=(height, width),
|
||||||
|
method=tf.image.ResizeMethod.AREA,
|
||||||
|
preserve_aspect_ratio=True,
|
||||||
|
)
|
||||||
|
image = image.numpy() # EagerTensor to np.array
|
||||||
|
image = dd.image.transform_and_pad_image(image, width, height)
|
||||||
|
image = image / 255.0
|
||||||
|
image_shape = image.shape
|
||||||
|
image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))
|
||||||
|
|
||||||
|
y = model.predict(image)[0]
|
||||||
|
|
||||||
|
result_dict = {}
|
||||||
|
|
||||||
|
for i, tag in enumerate(tags):
|
||||||
|
result_dict[tag] = y[i]
|
||||||
|
result_tags_out = []
|
||||||
|
result_tags_print = []
|
||||||
|
for tag in tags:
|
||||||
|
if result_dict[tag] >= threshold:
|
||||||
|
if tag.startswith("rating:"):
|
||||||
|
continue
|
||||||
|
result_tags_out.append(tag)
|
||||||
|
result_tags_print.append(f'{result_dict[tag]} {tag}')
|
||||||
|
|
||||||
|
print('\n'.join(sorted(result_tags_print, reverse=True)))
|
||||||
|
|
||||||
|
return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ')
|
||||||
|
|
||||||
|
|
||||||
|
def subprocess_init_no_cuda():
|
||||||
|
import os
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
||||||
|
|
||||||
|
|
||||||
|
def get_deepbooru_tags(pil_image, threshold=0.5):
|
||||||
|
context = get_context('spawn')
|
||||||
|
with ProcessPoolExecutor(initializer=subprocess_init_no_cuda, mp_context=context) as executor:
|
||||||
|
f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, )
|
||||||
|
ret = f.result() # will rethrow any exceptions
|
||||||
|
return ret
|
@ -36,6 +36,7 @@ errors.run(enable_tf32, "Enabling TF32")
|
|||||||
|
|
||||||
device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
|
device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
|
dtype_vae = torch.float16
|
||||||
|
|
||||||
def randn(seed, shape):
|
def randn(seed, shape):
|
||||||
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
|
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
|
||||||
@ -59,9 +60,12 @@ def randn_without_seed(shape):
|
|||||||
return torch.randn(shape, device=device)
|
return torch.randn(shape, device=device)
|
||||||
|
|
||||||
|
|
||||||
def autocast():
|
def autocast(disable=False):
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
|
||||||
|
if disable:
|
||||||
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
|
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
|
||||||
return contextlib.nullcontext()
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
|
@ -5,9 +5,8 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
import modules.esrgam_model_arch as arch
|
import modules.esrgan_model_arch as arch
|
||||||
from modules import shared, modelloader, images, devices
|
from modules import shared, modelloader, images, devices
|
||||||
from modules.paths import models_path
|
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
||||||
@ -76,7 +75,6 @@ class UpscalerESRGAN(Upscaler):
|
|||||||
self.model_name = "ESRGAN_4x"
|
self.model_name = "ESRGAN_4x"
|
||||||
self.scalers = []
|
self.scalers = []
|
||||||
self.user_path = dirname
|
self.user_path = dirname
|
||||||
self.model_path = os.path.join(models_path, self.name)
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
|
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
|
||||||
scalers = []
|
scalers = []
|
||||||
@ -111,7 +109,7 @@ class UpscalerESRGAN(Upscaler):
|
|||||||
print("Unable to load %s from %s" % (self.model_path, filename))
|
print("Unable to load %s from %s" % (self.model_path, filename))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
pretrained_net = torch.load(filename, map_location='cpu' if shared.device.type == 'mps' else None)
|
pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
||||||
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
||||||
|
|
||||||
pretrained_net = fix_model_layers(crt_model, pretrained_net)
|
pretrained_net = fix_model_layers(crt_model, pretrained_net)
|
||||||
|
@ -29,7 +29,7 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
|
|||||||
if extras_mode == 1:
|
if extras_mode == 1:
|
||||||
#convert file to pillow image
|
#convert file to pillow image
|
||||||
for img in image_folder:
|
for img in image_folder:
|
||||||
image = Image.fromarray(np.array(Image.open(img)))
|
image = Image.open(img)
|
||||||
imageArr.append(image)
|
imageArr.append(image)
|
||||||
imageNameArr.append(os.path.splitext(img.orig_name)[0])
|
imageNameArr.append(os.path.splitext(img.orig_name)[0])
|
||||||
else:
|
else:
|
||||||
@ -98,6 +98,10 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
|
|||||||
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
|
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
|
||||||
forced_filename=image_name if opts.use_original_name_batch else None)
|
forced_filename=image_name if opts.use_original_name_batch else None)
|
||||||
|
|
||||||
|
if opts.enable_pnginfo:
|
||||||
|
image.info = existing_pnginfo
|
||||||
|
image.info["extras"] = info
|
||||||
|
|
||||||
outputs.append(image)
|
outputs.append(image)
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
@ -170,8 +174,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
|
|||||||
print(f"Loading {secondary_model_info.filename}...")
|
print(f"Loading {secondary_model_info.filename}...")
|
||||||
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
|
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
|
||||||
|
|
||||||
theta_0 = primary_model['state_dict']
|
theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
|
||||||
theta_1 = secondary_model['state_dict']
|
theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
|
||||||
|
|
||||||
theta_funcs = {
|
theta_funcs = {
|
||||||
"Weighted Sum": weighted_sum,
|
"Weighted Sum": weighted_sum,
|
||||||
|
103
modules/hypernetwork.py
Normal file
103
modules/hypernetwork.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ldm.util import default
|
||||||
|
from modules import devices, shared
|
||||||
|
import torch
|
||||||
|
from torch import einsum
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
|
||||||
|
class HypernetworkModule(torch.nn.Module):
|
||||||
|
def __init__(self, dim, state_dict):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.linear1 = torch.nn.Linear(dim, dim * 2)
|
||||||
|
self.linear2 = torch.nn.Linear(dim * 2, dim)
|
||||||
|
|
||||||
|
self.load_state_dict(state_dict, strict=True)
|
||||||
|
self.to(devices.device)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x + (self.linear2(self.linear1(x)))
|
||||||
|
|
||||||
|
|
||||||
|
class Hypernetwork:
|
||||||
|
filename = None
|
||||||
|
name = None
|
||||||
|
|
||||||
|
def __init__(self, filename):
|
||||||
|
self.filename = filename
|
||||||
|
self.name = os.path.splitext(os.path.basename(filename))[0]
|
||||||
|
self.layers = {}
|
||||||
|
|
||||||
|
state_dict = torch.load(filename, map_location='cpu')
|
||||||
|
for size, sd in state_dict.items():
|
||||||
|
self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
|
||||||
|
|
||||||
|
|
||||||
|
def list_hypernetworks(path):
|
||||||
|
res = {}
|
||||||
|
for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
|
||||||
|
name = os.path.splitext(os.path.basename(filename))[0]
|
||||||
|
res[name] = filename
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def load_hypernetwork(filename):
|
||||||
|
path = shared.hypernetworks.get(filename, None)
|
||||||
|
if path is not None:
|
||||||
|
print(f"Loading hypernetwork {filename}")
|
||||||
|
try:
|
||||||
|
shared.loaded_hypernetwork = Hypernetwork(path)
|
||||||
|
except Exception:
|
||||||
|
print(f"Error loading hypernetwork {path}", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
else:
|
||||||
|
if shared.loaded_hypernetwork is not None:
|
||||||
|
print(f"Unloading hypernetwork")
|
||||||
|
|
||||||
|
shared.loaded_hypernetwork = None
|
||||||
|
|
||||||
|
|
||||||
|
def apply_hypernetwork(hypernetwork, context):
|
||||||
|
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
|
||||||
|
|
||||||
|
if hypernetwork_layers is None:
|
||||||
|
return context, context
|
||||||
|
|
||||||
|
context_k = hypernetwork_layers[0](context)
|
||||||
|
context_v = hypernetwork_layers[1](context)
|
||||||
|
return context_k, context_v
|
||||||
|
|
||||||
|
|
||||||
|
def attention_CrossAttention_forward(self, x, context=None, mask=None):
|
||||||
|
h = self.heads
|
||||||
|
|
||||||
|
q = self.to_q(x)
|
||||||
|
context = default(context, x)
|
||||||
|
|
||||||
|
context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context)
|
||||||
|
k = self.to_k(context_k)
|
||||||
|
v = self.to_v(context_v)
|
||||||
|
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||||
|
|
||||||
|
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
mask = rearrange(mask, 'b ... -> b (...)')
|
||||||
|
max_neg_value = -torch.finfo(sim.dtype).max
|
||||||
|
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||||
|
sim.masked_fill_(~mask, max_neg_value)
|
||||||
|
|
||||||
|
# attention, what we cannot get enough of
|
||||||
|
attn = sim.softmax(dim=-1)
|
||||||
|
|
||||||
|
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||||
|
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
return self.to_out(out)
|
@ -349,6 +349,38 @@ def get_next_sequence_number(path, basename):
|
|||||||
|
|
||||||
|
|
||||||
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
|
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
|
||||||
|
'''Save an image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`PIL.Image`):
|
||||||
|
The image to be saved.
|
||||||
|
path (`str`):
|
||||||
|
The directory to save the image. Note, the option `save_to_dirs` will make the image to be saved into a sub directory.
|
||||||
|
basename (`str`):
|
||||||
|
The base filename which will be applied to `filename pattern`.
|
||||||
|
seed, prompt, short_filename,
|
||||||
|
extension (`str`):
|
||||||
|
Image file extension, default is `png`.
|
||||||
|
pngsectionname (`str`):
|
||||||
|
Specify the name of the section which `info` will be saved in.
|
||||||
|
info (`str` or `PngImagePlugin.iTXt`):
|
||||||
|
PNG info chunks.
|
||||||
|
existing_info (`dict`):
|
||||||
|
Additional PNG info. `existing_info == {pngsectionname: info, ...}`
|
||||||
|
no_prompt:
|
||||||
|
TODO I don't know its meaning.
|
||||||
|
p (`StableDiffusionProcessing`)
|
||||||
|
forced_filename (`str`):
|
||||||
|
If specified, `basename` and filename pattern will be ignored.
|
||||||
|
save_to_dirs (bool):
|
||||||
|
If true, the image will be saved into a subdirectory of `path`.
|
||||||
|
|
||||||
|
Returns: (fullfn, txt_fullfn)
|
||||||
|
fullfn (`str`):
|
||||||
|
The full path of the saved imaged.
|
||||||
|
txt_fullfn (`str` or None):
|
||||||
|
If a text file is saved for this image, this will be its full path. Otherwise None.
|
||||||
|
'''
|
||||||
if short_filename or prompt is None or seed is None:
|
if short_filename or prompt is None or seed is None:
|
||||||
file_decoration = ""
|
file_decoration = ""
|
||||||
elif opts.save_to_dirs:
|
elif opts.save_to_dirs:
|
||||||
@ -424,7 +456,10 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||||||
piexif.insert(exif_bytes(), fullfn_without_extension + ".jpg")
|
piexif.insert(exif_bytes(), fullfn_without_extension + ".jpg")
|
||||||
|
|
||||||
if opts.save_txt and info is not None:
|
if opts.save_txt and info is not None:
|
||||||
with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file:
|
txt_fullfn = f"{fullfn_without_extension}.txt"
|
||||||
|
with open(txt_fullfn, "w", encoding="utf8") as file:
|
||||||
file.write(info + "\n")
|
file.write(info + "\n")
|
||||||
|
else:
|
||||||
|
txt_fullfn = None
|
||||||
|
|
||||||
return fullfn
|
return fullfn, txt_fullfn
|
||||||
|
@ -32,6 +32,8 @@ def process_batch(p, input_dir, output_dir, args):
|
|||||||
|
|
||||||
for i, image in enumerate(images):
|
for i, image in enumerate(images):
|
||||||
state.job = f"{i+1} out of {len(images)}"
|
state.job = f"{i+1} out of {len(images)}"
|
||||||
|
if state.skipped:
|
||||||
|
state.skipped = False
|
||||||
|
|
||||||
if state.interrupted:
|
if state.interrupted:
|
||||||
break
|
break
|
||||||
|
@ -140,11 +140,11 @@ class InterrogateModels:
|
|||||||
|
|
||||||
res = caption
|
res = caption
|
||||||
|
|
||||||
cilp_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
|
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
|
||||||
|
|
||||||
precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
|
precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
|
||||||
with torch.no_grad(), precision_scope("cuda"):
|
with torch.no_grad(), precision_scope("cuda"):
|
||||||
image_features = self.clip_model.encode_image(cilp_image).type(self.dtype)
|
image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
|
||||||
|
|
||||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
@ -7,13 +7,11 @@ from basicsr.utils.download_util import load_file_from_url
|
|||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
from modules.ldsr_model_arch import LDSR
|
from modules.ldsr_model_arch import LDSR
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.paths import models_path
|
|
||||||
|
|
||||||
|
|
||||||
class UpscalerLDSR(Upscaler):
|
class UpscalerLDSR(Upscaler):
|
||||||
def __init__(self, user_path):
|
def __init__(self, user_path):
|
||||||
self.name = "LDSR"
|
self.name = "LDSR"
|
||||||
self.model_path = os.path.join(models_path, self.name)
|
|
||||||
self.user_path = user_path
|
self.user_path = user_path
|
||||||
self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
|
self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
|
||||||
self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
|
self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import modules.safe
|
||||||
|
|
||||||
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||||
models_path = os.path.join(script_path, "models")
|
models_path = os.path.join(script_path, "models")
|
||||||
@ -12,6 +13,7 @@ possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'),
|
|||||||
for possible_sd_path in possible_sd_paths:
|
for possible_sd_path in possible_sd_paths:
|
||||||
if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
|
if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
|
||||||
sd_path = os.path.abspath(possible_sd_path)
|
sd_path = os.path.abspath(possible_sd_path)
|
||||||
|
break
|
||||||
|
|
||||||
assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)
|
assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)
|
||||||
|
|
||||||
|
@ -46,6 +46,12 @@ def apply_color_correction(correction, image):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def get_correct_sampler(p):
|
||||||
|
if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
|
||||||
|
return sd_samplers.samplers
|
||||||
|
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
|
||||||
|
return sd_samplers.samplers_for_img2img
|
||||||
|
|
||||||
class StableDiffusionProcessing:
|
class StableDiffusionProcessing:
|
||||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
|
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
|
||||||
self.sd_model = sd_model
|
self.sd_model = sd_model
|
||||||
@ -123,6 +129,7 @@ class Processed:
|
|||||||
self.index_of_first_image = index_of_first_image
|
self.index_of_first_image = index_of_first_image
|
||||||
self.styles = p.styles
|
self.styles = p.styles
|
||||||
self.job_timestamp = state.job_timestamp
|
self.job_timestamp = state.job_timestamp
|
||||||
|
self.clip_skip = opts.CLIP_stop_at_last_layers
|
||||||
|
|
||||||
self.eta = p.eta
|
self.eta = p.eta
|
||||||
self.ddim_discretize = p.ddim_discretize
|
self.ddim_discretize = p.ddim_discretize
|
||||||
@ -169,6 +176,7 @@ class Processed:
|
|||||||
"infotexts": self.infotexts,
|
"infotexts": self.infotexts,
|
||||||
"styles": self.styles,
|
"styles": self.styles,
|
||||||
"job_timestamp": self.job_timestamp,
|
"job_timestamp": self.job_timestamp,
|
||||||
|
"clip_skip": self.clip_skip,
|
||||||
}
|
}
|
||||||
|
|
||||||
return json.dumps(obj)
|
return json.dumps(obj)
|
||||||
@ -199,7 +207,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
|||||||
# enables the generation of additional tensors with noise that the sampler will use during its processing.
|
# enables the generation of additional tensors with noise that the sampler will use during its processing.
|
||||||
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
|
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
|
||||||
# produce the same images as with two batches [100], [101].
|
# produce the same images as with two batches [100], [101].
|
||||||
if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds:
|
if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
|
||||||
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
|
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
|
||||||
else:
|
else:
|
||||||
sampler_noises = None
|
sampler_noises = None
|
||||||
@ -239,6 +247,9 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
|||||||
if sampler_noises is not None:
|
if sampler_noises is not None:
|
||||||
cnt = p.sampler.number_of_needed_noises(p)
|
cnt = p.sampler.number_of_needed_noises(p)
|
||||||
|
|
||||||
|
if opts.eta_noise_seed_delta > 0:
|
||||||
|
torch.manual_seed(seed + opts.eta_noise_seed_delta)
|
||||||
|
|
||||||
for j in range(cnt):
|
for j in range(cnt):
|
||||||
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
|
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
|
||||||
|
|
||||||
@ -251,6 +262,13 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def decode_first_stage(model, x):
|
||||||
|
with devices.autocast(disable=x.dtype == devices.dtype_vae):
|
||||||
|
x = model.decode_first_stage(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
def get_fixed_seed(seed):
|
def get_fixed_seed(seed):
|
||||||
if seed is None or seed == '' or seed == -1:
|
if seed is None or seed == '' or seed == -1:
|
||||||
return int(random.randrange(4294967294))
|
return int(random.randrange(4294967294))
|
||||||
@ -266,14 +284,18 @@ def fix_seed(p):
|
|||||||
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
|
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
|
||||||
index = position_in_batch + iteration * p.batch_size
|
index = position_in_batch + iteration * p.batch_size
|
||||||
|
|
||||||
|
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
|
||||||
|
|
||||||
generation_params = {
|
generation_params = {
|
||||||
"Steps": p.steps,
|
"Steps": p.steps,
|
||||||
"Sampler": sd_samplers.samplers[p.sampler_index].name,
|
"Sampler": get_correct_sampler(p)[p.sampler_index].name,
|
||||||
"CFG scale": p.cfg_scale,
|
"CFG scale": p.cfg_scale,
|
||||||
"Seed": all_seeds[index],
|
"Seed": all_seeds[index],
|
||||||
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
||||||
"Size": f"{p.width}x{p.height}",
|
"Size": f"{p.width}x{p.height}",
|
||||||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
||||||
|
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
||||||
|
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')),
|
||||||
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
||||||
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
||||||
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
||||||
@ -281,6 +303,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
|||||||
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||||
"Denoising strength": getattr(p, 'denoising_strength', None),
|
"Denoising strength": getattr(p, 'denoising_strength', None),
|
||||||
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
||||||
|
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||||
|
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
||||||
}
|
}
|
||||||
|
|
||||||
generation_params.update(p.extra_generation_params)
|
generation_params.update(p.extra_generation_params)
|
||||||
@ -312,6 +336,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
os.makedirs(p.outpath_grids, exist_ok=True)
|
os.makedirs(p.outpath_grids, exist_ok=True)
|
||||||
|
|
||||||
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
||||||
|
modules.sd_hijack.model_hijack.clear_comments()
|
||||||
|
|
||||||
comments = {}
|
comments = {}
|
||||||
|
|
||||||
@ -341,7 +366,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
infotexts = []
|
infotexts = []
|
||||||
output_images = []
|
output_images = []
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad(), p.sd_model.ema_scope():
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
p.init(all_prompts, all_seeds, all_subseeds)
|
p.init(all_prompts, all_seeds, all_subseeds)
|
||||||
|
|
||||||
@ -349,6 +374,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
state.job_count = p.n_iter
|
state.job_count = p.n_iter
|
||||||
|
|
||||||
for n in range(p.n_iter):
|
for n in range(p.n_iter):
|
||||||
|
if state.skipped:
|
||||||
|
state.skipped = False
|
||||||
|
|
||||||
if state.interrupted:
|
if state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -375,15 +403,14 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
|
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
|
||||||
|
|
||||||
if state.interrupted:
|
if state.interrupted or state.skipped:
|
||||||
|
|
||||||
# if we are interruped, sample returns just noise
|
# if we are interrupted, sample returns just noise
|
||||||
# use the image collected previously in sampler loop
|
# use the image collected previously in sampler loop
|
||||||
samples_ddim = shared.state.current_latent
|
samples_ddim = shared.state.current_latent
|
||||||
|
|
||||||
samples_ddim = samples_ddim.to(devices.dtype)
|
samples_ddim = samples_ddim.to(devices.dtype_vae)
|
||||||
|
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
|
||||||
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
|
|
||||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
del samples_ddim
|
del samples_ddim
|
||||||
@ -436,7 +463,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
|
|
||||||
text = infotext(n, i)
|
text = infotext(n, i)
|
||||||
infotexts.append(text)
|
infotexts.append(text)
|
||||||
image.info["parameters"] = text
|
if opts.enable_pnginfo:
|
||||||
|
image.info["parameters"] = text
|
||||||
output_images.append(image)
|
output_images.append(image)
|
||||||
|
|
||||||
del x_samples_ddim
|
del x_samples_ddim
|
||||||
@ -455,7 +483,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if opts.return_grid:
|
if opts.return_grid:
|
||||||
text = infotext()
|
text = infotext()
|
||||||
infotexts.insert(0, text)
|
infotexts.insert(0, text)
|
||||||
grid.info["parameters"] = text
|
if opts.enable_pnginfo:
|
||||||
|
grid.info["parameters"] = text
|
||||||
output_images.insert(0, grid)
|
output_images.insert(0, grid)
|
||||||
index_of_first_image = 1
|
index_of_first_image = 1
|
||||||
|
|
||||||
@ -514,7 +543,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
if self.scale_latent:
|
if self.scale_latent:
|
||||||
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
|
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
|
||||||
else:
|
else:
|
||||||
decoded_samples = self.sd_model.decode_first_stage(samples)
|
decoded_samples = decode_first_stage(self.sd_model, samples)
|
||||||
|
|
||||||
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
|
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
|
||||||
decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
|
decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
|
||||||
|
@ -13,13 +13,14 @@ import lark
|
|||||||
|
|
||||||
schedule_parser = lark.Lark(r"""
|
schedule_parser = lark.Lark(r"""
|
||||||
!start: (prompt | /[][():]/+)*
|
!start: (prompt | /[][():]/+)*
|
||||||
prompt: (emphasized | scheduled | plain | WHITESPACE)*
|
prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
|
||||||
!emphasized: "(" prompt ")"
|
!emphasized: "(" prompt ")"
|
||||||
| "(" prompt ":" prompt ")"
|
| "(" prompt ":" prompt ")"
|
||||||
| "[" prompt "]"
|
| "[" prompt "]"
|
||||||
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
|
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
|
||||||
|
alternate: "[" prompt ("|" prompt)+ "]"
|
||||||
WHITESPACE: /\s+/
|
WHITESPACE: /\s+/
|
||||||
plain: /([^\\\[\]():]|\\.)+/
|
plain: /([^\\\[\]():|]|\\.)+/
|
||||||
%import common.SIGNED_NUMBER -> NUMBER
|
%import common.SIGNED_NUMBER -> NUMBER
|
||||||
""")
|
""")
|
||||||
|
|
||||||
@ -59,6 +60,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||||||
tree.children[-1] *= steps
|
tree.children[-1] *= steps
|
||||||
tree.children[-1] = min(steps, int(tree.children[-1]))
|
tree.children[-1] = min(steps, int(tree.children[-1]))
|
||||||
l.append(tree.children[-1])
|
l.append(tree.children[-1])
|
||||||
|
def alternate(self, tree):
|
||||||
|
l.extend(range(1, steps+1))
|
||||||
CollectSteps().visit(tree)
|
CollectSteps().visit(tree)
|
||||||
return sorted(set(l))
|
return sorted(set(l))
|
||||||
|
|
||||||
@ -67,6 +70,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||||||
def scheduled(self, args):
|
def scheduled(self, args):
|
||||||
before, after, _, when = args
|
before, after, _, when = args
|
||||||
yield before or () if step <= when else after
|
yield before or () if step <= when else after
|
||||||
|
def alternate(self, args):
|
||||||
|
yield next(args[(step - 1)%len(args)])
|
||||||
def start(self, args):
|
def start(self, args):
|
||||||
def flatten(x):
|
def flatten(x):
|
||||||
if type(x) == str:
|
if type(x) == str:
|
||||||
@ -239,6 +244,15 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
|
|||||||
|
|
||||||
conds_list.append(conds_for_batch)
|
conds_list.append(conds_for_batch)
|
||||||
|
|
||||||
|
# if prompts have wildly different lengths above the limit we'll get tensors fo different shapes
|
||||||
|
# and won't be able to torch.stack them. So this fixes that.
|
||||||
|
token_count = max([x.shape[0] for x in tensors])
|
||||||
|
for i in range(len(tensors)):
|
||||||
|
if tensors[i].shape[0] != token_count:
|
||||||
|
last_vector = tensors[i][-1:]
|
||||||
|
last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
|
||||||
|
tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
|
||||||
|
|
||||||
return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
|
return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,14 +8,12 @@ from basicsr.utils.download_util import load_file_from_url
|
|||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
from modules.paths import models_path
|
|
||||||
from modules.shared import cmd_opts, opts
|
from modules.shared import cmd_opts, opts
|
||||||
|
|
||||||
|
|
||||||
class UpscalerRealESRGAN(Upscaler):
|
class UpscalerRealESRGAN(Upscaler):
|
||||||
def __init__(self, path):
|
def __init__(self, path):
|
||||||
self.name = "RealESRGAN"
|
self.name = "RealESRGAN"
|
||||||
self.model_path = os.path.join(models_path, self.name)
|
|
||||||
self.user_path = path
|
self.user_path = path
|
||||||
super().__init__()
|
super().__init__()
|
||||||
try:
|
try:
|
||||||
|
93
modules/safe.py
Normal file
93
modules/safe.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
# this code is adapted from the script contributed by anon from /h/
|
||||||
|
|
||||||
|
import io
|
||||||
|
import pickle
|
||||||
|
import collections
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy
|
||||||
|
import _codecs
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
|
||||||
|
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
||||||
|
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
|
||||||
|
|
||||||
|
|
||||||
|
def encode(*args):
|
||||||
|
out = _codecs.encode(*args)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class RestrictedUnpickler(pickle.Unpickler):
|
||||||
|
def persistent_load(self, saved_id):
|
||||||
|
assert saved_id[0] == 'storage'
|
||||||
|
return TypedStorage()
|
||||||
|
|
||||||
|
def find_class(self, module, name):
|
||||||
|
if module == 'collections' and name == 'OrderedDict':
|
||||||
|
return getattr(collections, name)
|
||||||
|
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
|
||||||
|
return getattr(torch._utils, name)
|
||||||
|
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']:
|
||||||
|
return getattr(torch, name)
|
||||||
|
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
|
||||||
|
return getattr(torch.nn.modules.container, name)
|
||||||
|
if module == 'numpy.core.multiarray' and name == 'scalar':
|
||||||
|
return numpy.core.multiarray.scalar
|
||||||
|
if module == 'numpy' and name == 'dtype':
|
||||||
|
return numpy.dtype
|
||||||
|
if module == '_codecs' and name == 'encode':
|
||||||
|
return encode
|
||||||
|
if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
|
||||||
|
import pytorch_lightning.callbacks
|
||||||
|
return pytorch_lightning.callbacks.model_checkpoint
|
||||||
|
if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
|
||||||
|
import pytorch_lightning.callbacks.model_checkpoint
|
||||||
|
return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
|
||||||
|
if module == "__builtin__" and name == 'set':
|
||||||
|
return set
|
||||||
|
|
||||||
|
# Forbid everything else.
|
||||||
|
raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
|
||||||
|
|
||||||
|
|
||||||
|
def check_pt(filename):
|
||||||
|
try:
|
||||||
|
|
||||||
|
# new pytorch format is a zip file
|
||||||
|
with zipfile.ZipFile(filename) as z:
|
||||||
|
with z.open('archive/data.pkl') as file:
|
||||||
|
unpickler = RestrictedUnpickler(file)
|
||||||
|
unpickler.load()
|
||||||
|
|
||||||
|
except zipfile.BadZipfile:
|
||||||
|
|
||||||
|
# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
|
||||||
|
with open(filename, "rb") as file:
|
||||||
|
unpickler = RestrictedUnpickler(file)
|
||||||
|
for i in range(5):
|
||||||
|
unpickler.load()
|
||||||
|
|
||||||
|
|
||||||
|
def load(filename, *args, **kwargs):
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not shared.cmd_opts.disable_safe_unpickle:
|
||||||
|
check_pt(filename)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
|
||||||
|
print(f"You can skip this check with --disable-safe-unpickle commandline argument.", file=sys.stderr)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return unsafe_torch_load(filename, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
unsafe_torch_load = torch.load
|
||||||
|
torch.load = load
|
@ -9,14 +9,12 @@ from basicsr.utils.download_util import load_file_from_url
|
|||||||
|
|
||||||
import modules.upscaler
|
import modules.upscaler
|
||||||
from modules import devices, modelloader
|
from modules import devices, modelloader
|
||||||
from modules.paths import models_path
|
|
||||||
from modules.scunet_model_arch import SCUNet as net
|
from modules.scunet_model_arch import SCUNet as net
|
||||||
|
|
||||||
|
|
||||||
class UpscalerScuNET(modules.upscaler.Upscaler):
|
class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||||
def __init__(self, dirname):
|
def __init__(self, dirname):
|
||||||
self.name = "ScuNET"
|
self.name = "ScuNET"
|
||||||
self.model_path = os.path.join(models_path, self.name)
|
|
||||||
self.model_name = "ScuNET GAN"
|
self.model_name = "ScuNET GAN"
|
||||||
self.model_name2 = "ScuNET PSNR"
|
self.model_name2 = "ScuNET PSNR"
|
||||||
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
|
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
|
||||||
|
@ -40,7 +40,7 @@ class WMSA(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
attn_mask: should be (1 1 w p p),
|
attn_mask: should be (1 1 w p p),
|
||||||
"""
|
"""
|
||||||
# supporting sqaure.
|
# supporting square.
|
||||||
attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
|
attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
|
||||||
if self.type == 'W':
|
if self.type == 'W':
|
||||||
return attn_mask
|
return attn_mask
|
||||||
@ -65,7 +65,7 @@ class WMSA(nn.Module):
|
|||||||
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
|
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
|
||||||
h_windows = x.size(1)
|
h_windows = x.size(1)
|
||||||
w_windows = x.size(2)
|
w_windows = x.size(2)
|
||||||
# sqaure validation
|
# square validation
|
||||||
# assert h_windows == w_windows
|
# assert h_windows == w_windows
|
||||||
|
|
||||||
x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
|
x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
|
||||||
|
@ -18,15 +18,20 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
|||||||
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||||
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||||
|
|
||||||
|
|
||||||
def apply_optimizations():
|
def apply_optimizations():
|
||||||
undo_optimizations()
|
undo_optimizations()
|
||||||
|
|
||||||
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||||
|
|
||||||
if cmd_opts.opt_split_attention_v1:
|
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)):
|
||||||
|
print("Applying xformers cross attention optimization.")
|
||||||
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
|
||||||
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
|
||||||
|
elif cmd_opts.opt_split_attention_v1:
|
||||||
|
print("Applying v1 cross attention optimization.")
|
||||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
||||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
||||||
|
print("Applying cross attention optimization.")
|
||||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
|
||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
|
||||||
|
|
||||||
@ -39,6 +44,10 @@ def undo_optimizations():
|
|||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||||
|
|
||||||
|
|
||||||
|
def get_target_prompt_token_count(token_count):
|
||||||
|
return math.ceil(max(token_count, 1) / 75) * 75
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionModelHijack:
|
class StableDiffusionModelHijack:
|
||||||
fixes = None
|
fixes = None
|
||||||
comments = []
|
comments = []
|
||||||
@ -84,10 +93,12 @@ class StableDiffusionModelHijack:
|
|||||||
for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
|
for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
|
||||||
layer.padding_mode = 'circular' if enable else 'zeros'
|
layer.padding_mode = 'circular' if enable else 'zeros'
|
||||||
|
|
||||||
|
def clear_comments(self):
|
||||||
|
self.comments = []
|
||||||
|
|
||||||
def tokenize(self, text):
|
def tokenize(self, text):
|
||||||
max_length = self.clip.max_length - 2
|
|
||||||
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
|
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
|
||||||
return remade_batch_tokens[0], token_count, max_length
|
return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
|
||||||
|
|
||||||
|
|
||||||
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
@ -96,9 +107,10 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
self.wrapped = wrapped
|
self.wrapped = wrapped
|
||||||
self.hijack: StableDiffusionModelHijack = hijack
|
self.hijack: StableDiffusionModelHijack = hijack
|
||||||
self.tokenizer = wrapped.tokenizer
|
self.tokenizer = wrapped.tokenizer
|
||||||
self.max_length = wrapped.max_length
|
|
||||||
self.token_mults = {}
|
self.token_mults = {}
|
||||||
|
|
||||||
|
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
|
||||||
|
|
||||||
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
||||||
for text, ident in tokens_with_parens:
|
for text, ident in tokens_with_parens:
|
||||||
mult = 1.0
|
mult = 1.0
|
||||||
@ -116,9 +128,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
self.token_mults[ident] = mult
|
self.token_mults[ident] = mult
|
||||||
|
|
||||||
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
||||||
id_start = self.wrapped.tokenizer.bos_token_id
|
|
||||||
id_end = self.wrapped.tokenizer.eos_token_id
|
id_end = self.wrapped.tokenizer.eos_token_id
|
||||||
maxlen = self.wrapped.max_length
|
|
||||||
|
|
||||||
if opts.enable_emphasis:
|
if opts.enable_emphasis:
|
||||||
parsed = prompt_parser.parse_prompt_attention(line)
|
parsed = prompt_parser.parse_prompt_attention(line)
|
||||||
@ -130,6 +140,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
fixes = []
|
fixes = []
|
||||||
remade_tokens = []
|
remade_tokens = []
|
||||||
multipliers = []
|
multipliers = []
|
||||||
|
last_comma = -1
|
||||||
|
|
||||||
for tokens, (text, weight) in zip(tokenized, parsed):
|
for tokens, (text, weight) in zip(tokenized, parsed):
|
||||||
i = 0
|
i = 0
|
||||||
@ -138,31 +149,44 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
|
|
||||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||||
|
|
||||||
|
if token == self.comma_token:
|
||||||
|
last_comma = len(remade_tokens)
|
||||||
|
elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
|
||||||
|
last_comma += 1
|
||||||
|
reloc_tokens = remade_tokens[last_comma:]
|
||||||
|
reloc_mults = multipliers[last_comma:]
|
||||||
|
|
||||||
|
remade_tokens = remade_tokens[:last_comma]
|
||||||
|
length = len(remade_tokens)
|
||||||
|
|
||||||
|
rem = int(math.ceil(length / 75)) * 75 - length
|
||||||
|
remade_tokens += [id_end] * rem + reloc_tokens
|
||||||
|
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
|
||||||
|
|
||||||
if embedding is None:
|
if embedding is None:
|
||||||
remade_tokens.append(token)
|
remade_tokens.append(token)
|
||||||
multipliers.append(weight)
|
multipliers.append(weight)
|
||||||
i += 1
|
i += 1
|
||||||
else:
|
else:
|
||||||
emb_len = int(embedding.vec.shape[0])
|
emb_len = int(embedding.vec.shape[0])
|
||||||
fixes.append((len(remade_tokens), embedding))
|
iteration = len(remade_tokens) // 75
|
||||||
|
if (len(remade_tokens) + emb_len) // 75 != iteration:
|
||||||
|
rem = (75 * (iteration + 1) - len(remade_tokens))
|
||||||
|
remade_tokens += [id_end] * rem
|
||||||
|
multipliers += [1.0] * rem
|
||||||
|
iteration += 1
|
||||||
|
fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
|
||||||
remade_tokens += [0] * emb_len
|
remade_tokens += [0] * emb_len
|
||||||
multipliers += [weight] * emb_len
|
multipliers += [weight] * emb_len
|
||||||
used_custom_terms.append((embedding.name, embedding.checksum()))
|
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||||
i += embedding_length_in_tokens
|
i += embedding_length_in_tokens
|
||||||
|
|
||||||
if len(remade_tokens) > maxlen - 2:
|
|
||||||
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
|
||||||
ovf = remade_tokens[maxlen - 2:]
|
|
||||||
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
|
||||||
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
|
||||||
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
|
||||||
|
|
||||||
token_count = len(remade_tokens)
|
token_count = len(remade_tokens)
|
||||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
prompt_target_length = get_target_prompt_token_count(token_count)
|
||||||
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
|
tokens_to_add = prompt_target_length - len(remade_tokens)
|
||||||
|
|
||||||
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
remade_tokens = remade_tokens + [id_end] * tokens_to_add
|
||||||
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
multipliers = multipliers + [1.0] * tokens_to_add
|
||||||
|
|
||||||
return remade_tokens, fixes, multipliers, token_count
|
return remade_tokens, fixes, multipliers, token_count
|
||||||
|
|
||||||
@ -179,7 +203,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
if line in cache:
|
if line in cache:
|
||||||
remade_tokens, fixes, multipliers = cache[line]
|
remade_tokens, fixes, multipliers = cache[line]
|
||||||
else:
|
else:
|
||||||
remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
|
remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
|
||||||
|
token_count = max(current_token_count, token_count)
|
||||||
|
|
||||||
cache[line] = (remade_tokens, fixes, multipliers)
|
cache[line] = (remade_tokens, fixes, multipliers)
|
||||||
|
|
||||||
@ -193,7 +218,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
def process_text_old(self, text):
|
def process_text_old(self, text):
|
||||||
id_start = self.wrapped.tokenizer.bos_token_id
|
id_start = self.wrapped.tokenizer.bos_token_id
|
||||||
id_end = self.wrapped.tokenizer.eos_token_id
|
id_end = self.wrapped.tokenizer.eos_token_id
|
||||||
maxlen = self.wrapped.max_length
|
maxlen = self.wrapped.max_length # you get to stay at 77
|
||||||
used_custom_terms = []
|
used_custom_terms = []
|
||||||
remade_batch_tokens = []
|
remade_batch_tokens = []
|
||||||
overflowing_words = []
|
overflowing_words = []
|
||||||
@ -258,24 +283,62 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||||
|
|
||||||
def forward(self, text):
|
def forward(self, text):
|
||||||
|
use_old = opts.use_old_emphasis_implementation
|
||||||
if opts.use_old_emphasis_implementation:
|
if use_old:
|
||||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
|
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
|
||||||
else:
|
else:
|
||||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
||||||
|
|
||||||
self.hijack.fixes = hijack_fixes
|
self.hijack.comments += hijack_comments
|
||||||
self.hijack.comments = hijack_comments
|
|
||||||
|
|
||||||
if len(used_custom_terms) > 0:
|
if len(used_custom_terms) > 0:
|
||||||
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
||||||
|
|
||||||
|
if use_old:
|
||||||
|
self.hijack.fixes = hijack_fixes
|
||||||
|
return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
||||||
|
|
||||||
|
z = None
|
||||||
|
i = 0
|
||||||
|
while max(map(len, remade_batch_tokens)) != 0:
|
||||||
|
rem_tokens = [x[75:] for x in remade_batch_tokens]
|
||||||
|
rem_multipliers = [x[75:] for x in batch_multipliers]
|
||||||
|
|
||||||
|
self.hijack.fixes = []
|
||||||
|
for unfiltered in hijack_fixes:
|
||||||
|
fixes = []
|
||||||
|
for fix in unfiltered:
|
||||||
|
if fix[0] == i:
|
||||||
|
fixes.append(fix[1])
|
||||||
|
self.hijack.fixes.append(fixes)
|
||||||
|
|
||||||
|
z1 = self.process_tokens([x[:75] for x in remade_batch_tokens], [x[:75] for x in batch_multipliers])
|
||||||
|
z = z1 if z is None else torch.cat((z, z1), axis=-2)
|
||||||
|
|
||||||
|
remade_batch_tokens = rem_tokens
|
||||||
|
batch_multipliers = rem_multipliers
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
|
def process_tokens(self, remade_batch_tokens, batch_multipliers):
|
||||||
|
if not opts.use_old_emphasis_implementation:
|
||||||
|
remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
|
||||||
|
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
|
||||||
|
|
||||||
tokens = torch.asarray(remade_batch_tokens).to(device)
|
tokens = torch.asarray(remade_batch_tokens).to(device)
|
||||||
outputs = self.wrapped.transformer(input_ids=tokens)
|
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
|
||||||
z = outputs.last_hidden_state
|
|
||||||
|
if opts.CLIP_stop_at_last_layers > 1:
|
||||||
|
z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
|
||||||
|
z = self.wrapped.transformer.text_model.final_layer_norm(z)
|
||||||
|
else:
|
||||||
|
z = outputs.last_hidden_state
|
||||||
|
|
||||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||||
batch_multipliers = torch.asarray(batch_multipliers).to(device)
|
batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
|
||||||
|
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
|
||||||
original_mean = z.mean()
|
original_mean = z.mean()
|
||||||
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||||
new_mean = z.mean()
|
new_mean = z.mean()
|
||||||
|
@ -1,24 +1,39 @@
|
|||||||
import math
|
import math
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
|
|
||||||
from ldm.util import default
|
from ldm.util import default
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared, hypernetwork
|
||||||
|
|
||||||
|
|
||||||
|
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
|
||||||
|
try:
|
||||||
|
import xformers.ops
|
||||||
|
shared.xformers_available = True
|
||||||
|
except Exception:
|
||||||
|
print("Cannot import xformers", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
||||||
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
||||||
h = self.heads
|
h = self.heads
|
||||||
|
|
||||||
q = self.to_q(x)
|
q_in = self.to_q(x)
|
||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
k = self.to_k(context)
|
|
||||||
v = self.to_v(context)
|
|
||||||
del context, x
|
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
||||||
|
k_in = self.to_k(context_k)
|
||||||
|
v_in = self.to_v(context_v)
|
||||||
|
del context, context_k, context_v, x
|
||||||
|
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||||
|
del q_in, k_in, v_in
|
||||||
|
|
||||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
||||||
for i in range(0, q.shape[0], 2):
|
for i in range(0, q.shape[0], 2):
|
||||||
@ -31,6 +46,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
|||||||
|
|
||||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||||
del s2
|
del s2
|
||||||
|
del q, k, v
|
||||||
|
|
||||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||||
del r1
|
del r1
|
||||||
@ -38,21 +54,16 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
|||||||
return self.to_out(r2)
|
return self.to_out(r2)
|
||||||
|
|
||||||
|
|
||||||
# taken from https://github.com/Doggettx/stable-diffusion
|
# taken from https://github.com/Doggettx/stable-diffusion and modified
|
||||||
def split_cross_attention_forward(self, x, context=None, mask=None):
|
def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||||
h = self.heads
|
h = self.heads
|
||||||
|
|
||||||
q_in = self.to_q(x)
|
q_in = self.to_q(x)
|
||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
|
|
||||||
hypernetwork_layers = (shared.hypernetwork.layers if shared.hypernetwork is not None else {}).get(context.shape[2], None)
|
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
||||||
|
k_in = self.to_k(context_k)
|
||||||
if hypernetwork_layers is not None:
|
v_in = self.to_v(context_v)
|
||||||
k_in = self.to_k(hypernetwork_layers[0](context))
|
|
||||||
v_in = self.to_v(hypernetwork_layers[1](context))
|
|
||||||
else:
|
|
||||||
k_in = self.to_k(context)
|
|
||||||
v_in = self.to_v(context)
|
|
||||||
|
|
||||||
k_in *= self.scale
|
k_in *= self.scale
|
||||||
|
|
||||||
@ -104,6 +115,22 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
|||||||
|
|
||||||
return self.to_out(r2)
|
return self.to_out(r2)
|
||||||
|
|
||||||
|
def xformers_attention_forward(self, x, context=None, mask=None):
|
||||||
|
h = self.heads
|
||||||
|
q_in = self.to_q(x)
|
||||||
|
context = default(context, x)
|
||||||
|
|
||||||
|
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
||||||
|
k_in = self.to_k(context_k)
|
||||||
|
v_in = self.to_v(context_v)
|
||||||
|
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
|
||||||
|
del q_in, k_in, v_in
|
||||||
|
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
|
||||||
|
|
||||||
|
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
||||||
|
return self.to_out(out)
|
||||||
|
|
||||||
def cross_attention_attnblock_forward(self, x):
|
def cross_attention_attnblock_forward(self, x):
|
||||||
h_ = x
|
h_ = x
|
||||||
h_ = self.norm(h_)
|
h_ = self.norm(h_)
|
||||||
@ -166,3 +193,16 @@ def cross_attention_attnblock_forward(self, x):
|
|||||||
h3 += x
|
h3 += x
|
||||||
|
|
||||||
return h3
|
return h3
|
||||||
|
|
||||||
|
def xformers_attnblock_forward(self, x):
|
||||||
|
try:
|
||||||
|
h_ = x
|
||||||
|
h_ = self.norm(h_)
|
||||||
|
q1 = self.q(h_).contiguous()
|
||||||
|
k1 = self.k(h_).contiguous()
|
||||||
|
v = self.v(h_).contiguous()
|
||||||
|
out = xformers.ops.memory_efficient_attention(q1, k1, v)
|
||||||
|
out = self.proj_out(out)
|
||||||
|
return x + out
|
||||||
|
except NotImplementedError:
|
||||||
|
return cross_attention_attnblock_forward(self, x)
|
||||||
|
@ -5,7 +5,6 @@ from collections import namedtuple
|
|||||||
import torch
|
import torch
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
from modules import shared, modelloader, devices
|
from modules import shared, modelloader, devices
|
||||||
@ -14,7 +13,7 @@ from modules.paths import models_path
|
|||||||
model_dir = "Stable-diffusion"
|
model_dir = "Stable-diffusion"
|
||||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||||
|
|
||||||
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
|
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
|
||||||
checkpoints_list = {}
|
checkpoints_list = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -63,14 +62,20 @@ def list_models():
|
|||||||
if os.path.exists(cmd_ckpt):
|
if os.path.exists(cmd_ckpt):
|
||||||
h = model_hash(cmd_ckpt)
|
h = model_hash(cmd_ckpt)
|
||||||
title, short_model_name = modeltitle(cmd_ckpt, h)
|
title, short_model_name = modeltitle(cmd_ckpt, h)
|
||||||
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
|
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
|
||||||
shared.opts.data['sd_model_checkpoint'] = title
|
shared.opts.data['sd_model_checkpoint'] = title
|
||||||
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||||
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
||||||
for filename in model_list:
|
for filename in model_list:
|
||||||
h = model_hash(filename)
|
h = model_hash(filename)
|
||||||
title, short_model_name = modeltitle(filename, h)
|
title, short_model_name = modeltitle(filename, h)
|
||||||
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
|
|
||||||
|
basename, _ = os.path.splitext(filename)
|
||||||
|
config = basename + ".yaml"
|
||||||
|
if not os.path.exists(config):
|
||||||
|
config = shared.cmd_opts.config
|
||||||
|
|
||||||
|
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
|
||||||
|
|
||||||
|
|
||||||
def get_closet_checkpoint_match(searchString):
|
def get_closet_checkpoint_match(searchString):
|
||||||
@ -116,13 +121,24 @@ def select_checkpoint():
|
|||||||
return checkpoint_info
|
return checkpoint_info
|
||||||
|
|
||||||
|
|
||||||
def load_model_weights(model, checkpoint_file, sd_model_hash):
|
def get_state_dict_from_checkpoint(pl_sd):
|
||||||
|
if "state_dict" in pl_sd:
|
||||||
|
return pl_sd["state_dict"]
|
||||||
|
|
||||||
|
return pl_sd
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_weights(model, checkpoint_info):
|
||||||
|
checkpoint_file = checkpoint_info.filename
|
||||||
|
sd_model_hash = checkpoint_info.hash
|
||||||
|
|
||||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
|
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
|
||||||
|
|
||||||
pl_sd = torch.load(checkpoint_file, map_location="cpu")
|
pl_sd = torch.load(checkpoint_file, map_location="cpu")
|
||||||
if "global_step" in pl_sd:
|
if "global_step" in pl_sd:
|
||||||
print(f"Global Step: {pl_sd['global_step']}")
|
print(f"Global Step: {pl_sd['global_step']}")
|
||||||
sd = pl_sd["state_dict"]
|
|
||||||
|
sd = get_state_dict_from_checkpoint(pl_sd)
|
||||||
|
|
||||||
model.load_state_dict(sd, strict=False)
|
model.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
@ -133,8 +149,13 @@ def load_model_weights(model, checkpoint_file, sd_model_hash):
|
|||||||
model.half()
|
model.half()
|
||||||
|
|
||||||
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
||||||
|
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
|
||||||
|
|
||||||
vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
|
vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
|
||||||
|
|
||||||
|
if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
|
||||||
|
vae_file = shared.cmd_opts.vae_path
|
||||||
|
|
||||||
if os.path.exists(vae_file):
|
if os.path.exists(vae_file):
|
||||||
print(f"Loading VAE weights from: {vae_file}")
|
print(f"Loading VAE weights from: {vae_file}")
|
||||||
vae_ckpt = torch.load(vae_file, map_location="cpu")
|
vae_ckpt = torch.load(vae_file, map_location="cpu")
|
||||||
@ -142,17 +163,23 @@ def load_model_weights(model, checkpoint_file, sd_model_hash):
|
|||||||
|
|
||||||
model.first_stage_model.load_state_dict(vae_dict)
|
model.first_stage_model.load_state_dict(vae_dict)
|
||||||
|
|
||||||
|
model.first_stage_model.to(devices.dtype_vae)
|
||||||
|
|
||||||
model.sd_model_hash = sd_model_hash
|
model.sd_model_hash = sd_model_hash
|
||||||
model.sd_model_checkpint = checkpoint_file
|
model.sd_model_checkpoint = checkpoint_file
|
||||||
|
model.sd_checkpoint_info = checkpoint_info
|
||||||
|
|
||||||
|
|
||||||
def load_model():
|
def load_model():
|
||||||
from modules import lowvram, sd_hijack
|
from modules import lowvram, sd_hijack
|
||||||
checkpoint_info = select_checkpoint()
|
checkpoint_info = select_checkpoint()
|
||||||
|
|
||||||
sd_config = OmegaConf.load(shared.cmd_opts.config)
|
if checkpoint_info.config != shared.cmd_opts.config:
|
||||||
|
print(f"Loading config from: {checkpoint_info.config}")
|
||||||
|
|
||||||
|
sd_config = OmegaConf.load(checkpoint_info.config)
|
||||||
sd_model = instantiate_from_config(sd_config.model)
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
|
load_model_weights(sd_model, checkpoint_info)
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
|
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
|
||||||
@ -171,9 +198,13 @@ def reload_model_weights(sd_model, info=None):
|
|||||||
from modules import lowvram, devices, sd_hijack
|
from modules import lowvram, devices, sd_hijack
|
||||||
checkpoint_info = info or select_checkpoint()
|
checkpoint_info = info or select_checkpoint()
|
||||||
|
|
||||||
if sd_model.sd_model_checkpint == checkpoint_info.filename:
|
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
|
||||||
|
shared.sd_model = load_model()
|
||||||
|
return shared.sd_model
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
lowvram.send_everything_to_cpu()
|
lowvram.send_everything_to_cpu()
|
||||||
else:
|
else:
|
||||||
@ -181,7 +212,7 @@ def reload_model_weights(sd_model, info=None):
|
|||||||
|
|
||||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||||
|
|
||||||
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
|
load_model_weights(sd_model, checkpoint_info)
|
||||||
|
|
||||||
sd_hijack.model_hijack.hijack(sd_model)
|
sd_hijack.model_hijack.hijack(sd_model)
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ import inspect
|
|||||||
import k_diffusion.sampling
|
import k_diffusion.sampling
|
||||||
import ldm.models.diffusion.ddim
|
import ldm.models.diffusion.ddim
|
||||||
import ldm.models.diffusion.plms
|
import ldm.models.diffusion.plms
|
||||||
from modules import prompt_parser
|
from modules import prompt_parser, devices, processing
|
||||||
|
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
@ -83,7 +83,7 @@ def setup_img2img_steps(p, steps=None):
|
|||||||
|
|
||||||
|
|
||||||
def sample_to_image(samples):
|
def sample_to_image(samples):
|
||||||
x_sample = shared.sd_model.decode_first_stage(samples[0:1].type(shared.sd_model.dtype))[0]
|
x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0]
|
||||||
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
|
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||||
x_sample = x_sample.astype(np.uint8)
|
x_sample = x_sample.astype(np.uint8)
|
||||||
@ -106,7 +106,7 @@ def extended_tdqm(sequence, *args, desc=None, **kwargs):
|
|||||||
seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
|
seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
|
||||||
|
|
||||||
for x in seq:
|
for x in seq:
|
||||||
if state.interrupted:
|
if state.interrupted or state.skipped:
|
||||||
break
|
break
|
||||||
|
|
||||||
yield x
|
yield x
|
||||||
@ -142,6 +142,16 @@ class VanillaStableDiffusionSampler:
|
|||||||
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
|
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
|
||||||
cond = tensor
|
cond = tensor
|
||||||
|
|
||||||
|
# for DDIM, shapes must match, we can't just process cond and uncond independently;
|
||||||
|
# filling unconditional_conditioning with repeats of the last vector to match length is
|
||||||
|
# not 100% correct but should work well enough
|
||||||
|
if unconditional_conditioning.shape[1] < cond.shape[1]:
|
||||||
|
last_vector = unconditional_conditioning[:, -1:]
|
||||||
|
last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
|
||||||
|
unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
|
||||||
|
elif unconditional_conditioning.shape[1] > cond.shape[1]:
|
||||||
|
unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
|
||||||
|
|
||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
||||||
x_dec = img_orig * self.mask + self.nmask * x_dec
|
x_dec = img_orig * self.mask + self.nmask * x_dec
|
||||||
@ -171,7 +181,7 @@ class VanillaStableDiffusionSampler:
|
|||||||
|
|
||||||
self.initialize(p)
|
self.initialize(p)
|
||||||
|
|
||||||
# existing code fails with cetain step counts, like 9
|
# existing code fails with certain step counts, like 9
|
||||||
try:
|
try:
|
||||||
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -194,7 +204,7 @@ class VanillaStableDiffusionSampler:
|
|||||||
|
|
||||||
steps = steps or p.steps
|
steps = steps or p.steps
|
||||||
|
|
||||||
# existing code fails with cetin step counts, like 9
|
# existing code fails with certain step counts, like 9
|
||||||
try:
|
try:
|
||||||
samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
|
samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -221,18 +231,29 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
|
|
||||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||||
cond_in = torch.cat([tensor, uncond])
|
|
||||||
|
|
||||||
if shared.batch_cond_uncond:
|
if tensor.shape[1] == uncond.shape[1]:
|
||||||
x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
|
cond_in = torch.cat([tensor, uncond])
|
||||||
|
|
||||||
|
if shared.batch_cond_uncond:
|
||||||
|
x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
|
||||||
|
else:
|
||||||
|
x_out = torch.zeros_like(x_in)
|
||||||
|
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||||
|
a = batch_offset
|
||||||
|
b = a + batch_size
|
||||||
|
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
|
||||||
else:
|
else:
|
||||||
x_out = torch.zeros_like(x_in)
|
x_out = torch.zeros_like(x_in)
|
||||||
for batch_offset in range(0, x_out.shape[0], batch_size):
|
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
|
||||||
|
for batch_offset in range(0, tensor.shape[0], batch_size):
|
||||||
a = batch_offset
|
a = batch_offset
|
||||||
b = a + batch_size
|
b = min(a + batch_size, tensor.shape[0])
|
||||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
|
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b])
|
||||||
|
|
||||||
denoised_uncond = x_out[-batch_size:]
|
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond)
|
||||||
|
|
||||||
|
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||||
denoised = torch.clone(denoised_uncond)
|
denoised = torch.clone(denoised_uncond)
|
||||||
|
|
||||||
for i, conds in enumerate(conds_list):
|
for i, conds in enumerate(conds_list):
|
||||||
@ -254,7 +275,7 @@ def extended_trange(sampler, count, *args, **kwargs):
|
|||||||
seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
|
seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
|
||||||
|
|
||||||
for x in seq:
|
for x in seq:
|
||||||
if state.interrupted:
|
if state.interrupted or state.skipped:
|
||||||
break
|
break
|
||||||
|
|
||||||
if sampler.stop_at is not None and x > sampler.stop_at:
|
if sampler.stop_at is not None and x > sampler.stop_at:
|
||||||
|
@ -13,7 +13,7 @@ import modules.memmon
|
|||||||
import modules.sd_models
|
import modules.sd_models
|
||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.devices as devices
|
import modules.devices as devices
|
||||||
from modules import sd_samplers
|
from modules import sd_samplers, hypernetwork
|
||||||
from modules.paths import models_path, script_path, sd_path
|
from modules.paths import models_path, script_path, sd_path
|
||||||
|
|
||||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||||
@ -25,10 +25,10 @@ parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to director
|
|||||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||||
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
||||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||||
|
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
|
||||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
||||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||||
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
|
||||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||||
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
||||||
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
||||||
@ -44,6 +44,9 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director
|
|||||||
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
|
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
|
||||||
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
|
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
|
||||||
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
|
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
|
||||||
|
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
||||||
|
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||||
|
parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
|
||||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
|
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
|
||||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||||
@ -63,6 +66,8 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR
|
|||||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||||
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
||||||
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
||||||
|
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
|
||||||
|
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
||||||
|
|
||||||
|
|
||||||
cmd_opts = parser.parse_args()
|
cmd_opts = parser.parse_args()
|
||||||
@ -74,21 +79,15 @@ device = devices.device
|
|||||||
|
|
||||||
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
|
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
|
||||||
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
|
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
|
||||||
|
xformers_available = False
|
||||||
config_filename = cmd_opts.ui_settings_file
|
config_filename = cmd_opts.ui_settings_file
|
||||||
|
|
||||||
|
hypernetworks = hypernetwork.list_hypernetworks(os.path.join(models_path, 'hypernetworks'))
|
||||||
def reload_hypernetworks():
|
loaded_hypernetwork = None
|
||||||
from modules.hypernetwork import hypernetwork
|
|
||||||
hypernetworks.clear()
|
|
||||||
hypernetworks.update(hypernetwork.load_hypernetworks(cmd_opts.hypernetwork_dir))
|
|
||||||
|
|
||||||
|
|
||||||
hypernetworks = {}
|
|
||||||
hypernetwork = None
|
|
||||||
|
|
||||||
|
|
||||||
class State:
|
class State:
|
||||||
|
skipped = False
|
||||||
interrupted = False
|
interrupted = False
|
||||||
job = ""
|
job = ""
|
||||||
job_no = 0
|
job_no = 0
|
||||||
@ -101,6 +100,9 @@ class State:
|
|||||||
current_image_sampling_step = 0
|
current_image_sampling_step = 0
|
||||||
textinfo = None
|
textinfo = None
|
||||||
|
|
||||||
|
def skip(self):
|
||||||
|
self.skipped = True
|
||||||
|
|
||||||
def interrupt(self):
|
def interrupt(self):
|
||||||
self.interrupted = True
|
self.interrupted = True
|
||||||
|
|
||||||
@ -123,8 +125,6 @@ prompt_styles = modules.styles.StyleDatabase(styles_filename)
|
|||||||
interrogator = modules.interrogate.InterrogateModels("interrogate")
|
interrogator = modules.interrogate.InterrogateModels("interrogate")
|
||||||
|
|
||||||
face_restorers = []
|
face_restorers = []
|
||||||
# This was moved to webui.py with the other model "setup" calls.
|
|
||||||
# modules.sd_models.list_models()
|
|
||||||
|
|
||||||
|
|
||||||
def realesrgan_models_names():
|
def realesrgan_models_names():
|
||||||
@ -133,18 +133,19 @@ def realesrgan_models_names():
|
|||||||
|
|
||||||
|
|
||||||
class OptionInfo:
|
class OptionInfo:
|
||||||
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None):
|
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False):
|
||||||
self.default = default
|
self.default = default
|
||||||
self.label = label
|
self.label = label
|
||||||
self.component = component
|
self.component = component
|
||||||
self.component_args = component_args
|
self.component_args = component_args
|
||||||
self.onchange = onchange
|
self.onchange = onchange
|
||||||
self.section = None
|
self.section = None
|
||||||
|
self.show_on_main_page = show_on_main_page
|
||||||
|
|
||||||
|
|
||||||
def options_section(section_identifer, options_dict):
|
def options_section(section_identifier, options_dict):
|
||||||
for k, v in options_dict.items():
|
for k, v in options_dict.items():
|
||||||
v.section = section_identifer
|
v.section = section_identifier
|
||||||
|
|
||||||
return options_dict
|
return options_dict
|
||||||
|
|
||||||
@ -172,6 +173,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
|||||||
|
|
||||||
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
|
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
|
||||||
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
|
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
|
||||||
|
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
||||||
@ -216,7 +218,7 @@ options_templates.update(options_section(('system', "System"), {
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}),
|
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True),
|
||||||
"sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}),
|
"sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}),
|
||||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||||
@ -225,7 +227,9 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||||||
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||||
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||||
|
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
|
||||||
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
|
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
|
||||||
|
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
||||||
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
|
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@ -236,17 +240,19 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
|
|||||||
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
||||||
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
||||||
"interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
|
"interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
|
||||||
|
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('ui', "User interface"), {
|
options_templates.update(options_section(('ui', "User interface"), {
|
||||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||||
"show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
|
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
|
||||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||||
|
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
|
||||||
"font": OptionInfo("", "Font for image grids that have text"),
|
"font": OptionInfo("", "Font for image grids that have text"),
|
||||||
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
||||||
"js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||||
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@ -258,6 +264,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
|||||||
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
|
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,9 +8,9 @@ from basicsr.utils.download_util import load_file_from_url
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from modules import modelloader
|
from modules import modelloader
|
||||||
from modules.paths import models_path
|
|
||||||
from modules.shared import cmd_opts, opts, device
|
from modules.shared import cmd_opts, opts, device
|
||||||
from modules.swinir_model_arch import SwinIR as net
|
from modules.swinir_model_arch import SwinIR as net
|
||||||
|
from modules.swinir_model_arch_v2 import Swin2SR as net2
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
|
|
||||||
precision_scope = (
|
precision_scope = (
|
||||||
@ -25,7 +25,6 @@ class UpscalerSwinIR(Upscaler):
|
|||||||
"/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
|
"/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
|
||||||
"-L_x4_GAN.pth "
|
"-L_x4_GAN.pth "
|
||||||
self.model_name = "SwinIR 4x"
|
self.model_name = "SwinIR 4x"
|
||||||
self.model_path = os.path.join(models_path, self.name)
|
|
||||||
self.user_path = dirname
|
self.user_path = dirname
|
||||||
super().__init__()
|
super().__init__()
|
||||||
scalers = []
|
scalers = []
|
||||||
@ -59,22 +58,42 @@ class UpscalerSwinIR(Upscaler):
|
|||||||
filename = path
|
filename = path
|
||||||
if filename is None or not os.path.exists(filename):
|
if filename is None or not os.path.exists(filename):
|
||||||
return None
|
return None
|
||||||
model = net(
|
if filename.endswith(".v2.pth"):
|
||||||
|
model = net2(
|
||||||
upscale=scale,
|
upscale=scale,
|
||||||
in_chans=3,
|
in_chans=3,
|
||||||
img_size=64,
|
img_size=64,
|
||||||
window_size=8,
|
window_size=8,
|
||||||
img_range=1.0,
|
img_range=1.0,
|
||||||
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
|
depths=[6, 6, 6, 6, 6, 6],
|
||||||
embed_dim=240,
|
embed_dim=180,
|
||||||
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
|
num_heads=[6, 6, 6, 6, 6, 6],
|
||||||
mlp_ratio=2,
|
mlp_ratio=2,
|
||||||
upsampler="nearest+conv",
|
upsampler="nearest+conv",
|
||||||
resi_connection="3conv",
|
resi_connection="1conv",
|
||||||
)
|
)
|
||||||
|
params = None
|
||||||
|
else:
|
||||||
|
model = net(
|
||||||
|
upscale=scale,
|
||||||
|
in_chans=3,
|
||||||
|
img_size=64,
|
||||||
|
window_size=8,
|
||||||
|
img_range=1.0,
|
||||||
|
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
|
||||||
|
embed_dim=240,
|
||||||
|
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
|
||||||
|
mlp_ratio=2,
|
||||||
|
upsampler="nearest+conv",
|
||||||
|
resi_connection="3conv",
|
||||||
|
)
|
||||||
|
params = "params_ema"
|
||||||
|
|
||||||
pretrained_model = torch.load(filename)
|
pretrained_model = torch.load(filename)
|
||||||
model.load_state_dict(pretrained_model["params_ema"], strict=True)
|
if params is not None:
|
||||||
|
model.load_state_dict(pretrained_model[params], strict=True)
|
||||||
|
else:
|
||||||
|
model.load_state_dict(pretrained_model, strict=True)
|
||||||
if not cmd_opts.no_half:
|
if not cmd_opts.no_half:
|
||||||
model = model.half()
|
model = model.half()
|
||||||
return model
|
return model
|
||||||
|
@ -166,7 +166,7 @@ class SwinTransformerBlock(nn.Module):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
dim (int): Number of input channels.
|
dim (int): Number of input channels.
|
||||||
input_resolution (tuple[int]): Input resulotion.
|
input_resolution (tuple[int]): Input resolution.
|
||||||
num_heads (int): Number of attention heads.
|
num_heads (int): Number of attention heads.
|
||||||
window_size (int): Window size.
|
window_size (int): Window size.
|
||||||
shift_size (int): Shift size for SW-MSA.
|
shift_size (int): Shift size for SW-MSA.
|
||||||
|
1017
modules/swinir_model_arch_v2.py
Normal file
1017
modules/swinir_model_arch_v2.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -15,11 +15,10 @@ re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
|
|||||||
|
|
||||||
|
|
||||||
class PersonalizedBase(Dataset):
|
class PersonalizedBase(Dataset):
|
||||||
def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
|
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
|
||||||
|
|
||||||
self.placeholder_token = placeholder_token
|
self.placeholder_token = placeholder_token
|
||||||
|
|
||||||
self.size = size
|
|
||||||
self.width = width
|
self.width = width
|
||||||
self.height = height
|
self.height = height
|
||||||
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||||
|
@ -7,8 +7,9 @@ import tqdm
|
|||||||
from modules import shared, images
|
from modules import shared, images
|
||||||
|
|
||||||
|
|
||||||
def preprocess(process_src, process_dst, process_flip, process_split, process_caption):
|
def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption):
|
||||||
size = 512
|
width = process_width
|
||||||
|
height = process_height
|
||||||
src = os.path.abspath(process_src)
|
src = os.path.abspath(process_src)
|
||||||
dst = os.path.abspath(process_dst)
|
dst = os.path.abspath(process_dst)
|
||||||
|
|
||||||
@ -55,23 +56,23 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca
|
|||||||
is_wide = ratio < 1 / 1.35
|
is_wide = ratio < 1 / 1.35
|
||||||
|
|
||||||
if process_split and is_tall:
|
if process_split and is_tall:
|
||||||
img = img.resize((size, size * img.height // img.width))
|
img = img.resize((width, height * img.height // img.width))
|
||||||
|
|
||||||
top = img.crop((0, 0, size, size))
|
top = img.crop((0, 0, width, height))
|
||||||
save_pic(top, index)
|
save_pic(top, index)
|
||||||
|
|
||||||
bot = img.crop((0, img.height - size, size, img.height))
|
bot = img.crop((0, img.height - height, width, img.height))
|
||||||
save_pic(bot, index)
|
save_pic(bot, index)
|
||||||
elif process_split and is_wide:
|
elif process_split and is_wide:
|
||||||
img = img.resize((size * img.width // img.height, size))
|
img = img.resize((width * img.width // img.height, height))
|
||||||
|
|
||||||
left = img.crop((0, 0, size, size))
|
left = img.crop((0, 0, width, height))
|
||||||
save_pic(left, index)
|
save_pic(left, index)
|
||||||
|
|
||||||
right = img.crop((img.width - size, 0, img.width, size))
|
right = img.crop((img.width - width, 0, img.width, height))
|
||||||
save_pic(right, index)
|
save_pic(right, index)
|
||||||
else:
|
else:
|
||||||
img = images.resize_image(1, img, size, size)
|
img = images.resize_image(1, img, width, height)
|
||||||
save_pic(img, index)
|
save_pic(img, index)
|
||||||
|
|
||||||
shared.state.nextjob()
|
shared.state.nextjob()
|
||||||
|
@ -156,7 +156,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
|
|||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
|
||||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
|
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file):
|
||||||
assert embedding_name, 'embedding not selected'
|
assert embedding_name, 'embedding not selected'
|
||||||
|
|
||||||
shared.state.textinfo = "Initializing textual inversion training..."
|
shared.state.textinfo = "Initializing textual inversion training..."
|
||||||
@ -182,7 +182,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
|
|||||||
|
|
||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
|
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
|
||||||
|
|
||||||
hijack = sd_hijack.model_hijack
|
hijack = sd_hijack.model_hijack
|
||||||
|
|
||||||
@ -200,6 +200,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
|
|||||||
if ititial_step > steps:
|
if ititial_step > steps:
|
||||||
return embedding, filename
|
return embedding, filename
|
||||||
|
|
||||||
|
tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
|
||||||
|
epoch_len = (tr_img_len * num_repeats) + tr_img_len
|
||||||
|
|
||||||
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
||||||
for i, (x, text) in pbar:
|
for i, (x, text) in pbar:
|
||||||
embedding.step = i + ititial_step
|
embedding.step = i + ititial_step
|
||||||
@ -223,7 +226,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
pbar.set_description(f"loss: {losses.mean():.7f}")
|
epoch_num = embedding.step // epoch_len
|
||||||
|
epoch_step = embedding.step - (epoch_num * epoch_len) + 1
|
||||||
|
|
||||||
|
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
|
||||||
|
|
||||||
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
|
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
|
||||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
||||||
@ -236,6 +242,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
|
|||||||
sd_model=shared.sd_model,
|
sd_model=shared.sd_model,
|
||||||
prompt=text,
|
prompt=text,
|
||||||
steps=20,
|
steps=20,
|
||||||
|
height=training_height,
|
||||||
|
width=training_width,
|
||||||
do_not_save_grid=True,
|
do_not_save_grid=True,
|
||||||
do_not_save_samples=True,
|
do_not_save_samples=True,
|
||||||
)
|
)
|
||||||
|
172
modules/ui.py
172
modules/ui.py
@ -25,6 +25,8 @@ import gradio.routes
|
|||||||
from modules import sd_hijack
|
from modules import sd_hijack
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
from modules.shared import opts, cmd_opts
|
from modules.shared import opts, cmd_opts
|
||||||
|
if cmd_opts.deepdanbooru:
|
||||||
|
from modules.deepbooru import get_deepbooru_tags
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.sd_samplers import samplers, samplers_for_img2img
|
from modules.sd_samplers import samplers, samplers_for_img2img
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
@ -39,7 +41,7 @@ from modules.images import save_image
|
|||||||
import modules.textual_inversion.ui
|
import modules.textual_inversion.ui
|
||||||
import modules.hypernetwork.ui
|
import modules.hypernetwork.ui
|
||||||
|
|
||||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
|
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
||||||
mimetypes.init()
|
mimetypes.init()
|
||||||
mimetypes.add_type('application/javascript', '.js')
|
mimetypes.add_type('application/javascript', '.js')
|
||||||
|
|
||||||
@ -99,11 +101,12 @@ def send_gradio_gallery_to_image(x):
|
|||||||
return image_from_url_text(x[0])
|
return image_from_url_text(x[0])
|
||||||
|
|
||||||
|
|
||||||
def save_files(js_data, images, index):
|
def save_files(js_data, images, do_make_zip, index):
|
||||||
import csv
|
import csv
|
||||||
filenames = []
|
filenames = []
|
||||||
|
fullfns = []
|
||||||
|
|
||||||
#quick dictionary to class object conversion. Its neccesary due apply_filename_pattern requiring it
|
#quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
|
||||||
class MyObject:
|
class MyObject:
|
||||||
def __init__(self, d=None):
|
def __init__(self, d=None):
|
||||||
if d is not None:
|
if d is not None:
|
||||||
@ -138,14 +141,29 @@ def save_files(js_data, images, index):
|
|||||||
is_grid = image_index < p.index_of_first_image
|
is_grid = image_index < p.index_of_first_image
|
||||||
i = 0 if is_grid else (image_index - p.index_of_first_image)
|
i = 0 if is_grid else (image_index - p.index_of_first_image)
|
||||||
|
|
||||||
fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
|
fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
|
||||||
|
|
||||||
filename = os.path.relpath(fullfn, path)
|
filename = os.path.relpath(fullfn, path)
|
||||||
filenames.append(filename)
|
filenames.append(filename)
|
||||||
|
fullfns.append(fullfn)
|
||||||
|
if txt_fullfn:
|
||||||
|
filenames.append(os.path.basename(txt_fullfn))
|
||||||
|
fullfns.append(txt_fullfn)
|
||||||
|
|
||||||
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
|
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
|
||||||
|
|
||||||
return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
|
# Make Zip
|
||||||
|
if do_make_zip:
|
||||||
|
zip_filepath = os.path.join(path, "images.zip")
|
||||||
|
|
||||||
|
from zipfile import ZipFile
|
||||||
|
with ZipFile(zip_filepath, "w") as zip_file:
|
||||||
|
for i in range(len(fullfns)):
|
||||||
|
with open(fullfns[i], mode="rb") as f:
|
||||||
|
zip_file.writestr(filenames[i], f.read())
|
||||||
|
fullfns.insert(0, zip_filepath)
|
||||||
|
|
||||||
|
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
|
||||||
|
|
||||||
|
|
||||||
def wrap_gradio_call(func, extra_outputs=None):
|
def wrap_gradio_call(func, extra_outputs=None):
|
||||||
@ -192,6 +210,7 @@ def wrap_gradio_call(func, extra_outputs=None):
|
|||||||
# last item is always HTML
|
# last item is always HTML
|
||||||
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
|
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
|
||||||
|
|
||||||
|
shared.state.skipped = False
|
||||||
shared.state.interrupted = False
|
shared.state.interrupted = False
|
||||||
shared.state.job_count = 0
|
shared.state.job_count = 0
|
||||||
|
|
||||||
@ -292,6 +311,11 @@ def interrogate(image):
|
|||||||
return gr_show(True) if prompt is None else prompt
|
return gr_show(True) if prompt is None else prompt
|
||||||
|
|
||||||
|
|
||||||
|
def interrogate_deepbooru(image):
|
||||||
|
prompt = get_deepbooru_tags(image, opts.interrogate_deepbooru_score_threshold)
|
||||||
|
return gr_show(True) if prompt is None else prompt
|
||||||
|
|
||||||
|
|
||||||
def create_seed_inputs():
|
def create_seed_inputs():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Box():
|
with gr.Box():
|
||||||
@ -412,24 +436,36 @@ def create_toprow(is_img2img):
|
|||||||
|
|
||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
|
||||||
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
|
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
|
||||||
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
|
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
|
||||||
|
|
||||||
|
skip.click(
|
||||||
|
fn=lambda: shared.state.skip(),
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
)
|
||||||
|
|
||||||
interrupt.click(
|
interrupt.click(
|
||||||
fn=lambda: shared.state.interrupt(),
|
fn=lambda: shared.state.interrupt(),
|
||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[],
|
outputs=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row(scale=1):
|
||||||
if is_img2img:
|
if is_img2img:
|
||||||
interrogate = gr.Button('Interrogate', elem_id="interrogate")
|
interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
|
||||||
|
if cmd_opts.deepdanbooru:
|
||||||
|
deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
|
||||||
|
else:
|
||||||
|
deepbooru = None
|
||||||
else:
|
else:
|
||||||
interrogate = None
|
interrogate = None
|
||||||
|
deepbooru = None
|
||||||
prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
|
prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
|
||||||
save_style = gr.Button('Create style', elem_id="style_create")
|
save_style = gr.Button('Create style', elem_id="style_create")
|
||||||
|
|
||||||
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste, token_counter, token_button
|
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
|
||||||
|
|
||||||
|
|
||||||
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
|
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
|
||||||
@ -458,7 +494,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
import modules.txt2img
|
import modules.txt2img
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||||
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False)
|
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False)
|
||||||
dummy_component = gr.Label(visible=False)
|
dummy_component = gr.Label(visible=False)
|
||||||
|
|
||||||
with gr.Row(elem_id='txt2img_progress_row'):
|
with gr.Row(elem_id='txt2img_progress_row'):
|
||||||
@ -489,7 +525,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
|
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
|
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
|
||||||
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
|
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
|
||||||
|
|
||||||
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
|
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
|
||||||
@ -514,6 +550,12 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
|
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
|
||||||
open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id)
|
open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
|
||||||
|
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
html_info = gr.HTML()
|
html_info = gr.HTML()
|
||||||
generation_info = gr.Textbox(visible=False)
|
generation_info = gr.Textbox(visible=False)
|
||||||
@ -563,13 +605,15 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
|
|
||||||
save.click(
|
save.click(
|
||||||
fn=wrap_gradio_call(save_files),
|
fn=wrap_gradio_call(save_files),
|
||||||
_js="(x, y, z) => [x, y, selected_gallery_index()]",
|
_js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
|
||||||
inputs=[
|
inputs=[
|
||||||
generation_info,
|
generation_info,
|
||||||
txt2img_gallery,
|
txt2img_gallery,
|
||||||
|
do_make_zip,
|
||||||
html_info,
|
html_info,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
|
download_files,
|
||||||
html_info,
|
html_info,
|
||||||
html_info,
|
html_info,
|
||||||
html_info,
|
html_info,
|
||||||
@ -610,7 +654,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
|
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||||
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True)
|
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True)
|
||||||
|
|
||||||
with gr.Row(elem_id='img2img_progress_row'):
|
with gr.Row(elem_id='img2img_progress_row'):
|
||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
@ -667,7 +711,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
tiling = gr.Checkbox(label='Tiling', value=False)
|
tiling = gr.Checkbox(label='Tiling', value=False)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
|
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
|
||||||
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
|
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
|
||||||
|
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
@ -694,6 +738,12 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
|
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
|
||||||
open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id)
|
open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
|
||||||
|
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
html_info = gr.HTML()
|
html_info = gr.HTML()
|
||||||
generation_info = gr.Textbox(visible=False)
|
generation_info = gr.Textbox(visible=False)
|
||||||
@ -767,15 +817,24 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
outputs=[img2img_prompt],
|
outputs=[img2img_prompt],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cmd_opts.deepdanbooru:
|
||||||
|
img2img_deepbooru.click(
|
||||||
|
fn=interrogate_deepbooru,
|
||||||
|
inputs=[init_img],
|
||||||
|
outputs=[img2img_prompt],
|
||||||
|
)
|
||||||
|
|
||||||
save.click(
|
save.click(
|
||||||
fn=wrap_gradio_call(save_files),
|
fn=wrap_gradio_call(save_files),
|
||||||
_js="(x, y, z) => [x, y, selected_gallery_index()]",
|
_js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
|
||||||
inputs=[
|
inputs=[
|
||||||
generation_info,
|
generation_info,
|
||||||
img2img_gallery,
|
img2img_gallery,
|
||||||
html_info
|
do_make_zip,
|
||||||
|
html_info,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
|
download_files,
|
||||||
html_info,
|
html_info,
|
||||||
html_info,
|
html_info,
|
||||||
html_info,
|
html_info,
|
||||||
@ -903,7 +962,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
|
|
||||||
extras_send_to_inpaint.click(
|
extras_send_to_inpaint.click(
|
||||||
fn=lambda x: image_from_url_text(x),
|
fn=lambda x: image_from_url_text(x),
|
||||||
_js="extract_image_from_gallery_img2img",
|
_js="extract_image_from_gallery_inpaint",
|
||||||
inputs=[result_images],
|
inputs=[result_images],
|
||||||
outputs=[init_img_with_mask],
|
outputs=[init_img_with_mask],
|
||||||
)
|
)
|
||||||
@ -939,7 +998,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
custom_name = gr.Textbox(label="Custom Name (Optional)")
|
custom_name = gr.Textbox(label="Custom Name (Optional)")
|
||||||
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3)
|
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3)
|
||||||
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
|
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
|
||||||
save_as_half = gr.Checkbox(value=False, label="Safe as float16")
|
save_as_half = gr.Checkbox(value=False, label="Save as float16")
|
||||||
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
|
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
|
||||||
|
|
||||||
with gr.Column(variant='panel'):
|
with gr.Column(variant='panel'):
|
||||||
@ -983,11 +1042,13 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
|
|
||||||
process_src = gr.Textbox(label='Source directory')
|
process_src = gr.Textbox(label='Source directory')
|
||||||
process_dst = gr.Textbox(label='Destination directory')
|
process_dst = gr.Textbox(label='Destination directory')
|
||||||
|
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
||||||
|
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
process_flip = gr.Checkbox(label='Flip')
|
process_flip = gr.Checkbox(label='Create flipped copies')
|
||||||
process_split = gr.Checkbox(label='Split into two')
|
process_split = gr.Checkbox(label='Split oversized images into two')
|
||||||
process_caption = gr.Checkbox(label='Add caption')
|
process_caption = gr.Checkbox(label='Use BLIP caption as filename')
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=3):
|
with gr.Column(scale=3):
|
||||||
@ -997,14 +1058,17 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
run_preprocess = gr.Button(value="Preprocess", variant='primary')
|
run_preprocess = gr.Button(value="Preprocess", variant='primary')
|
||||||
|
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
|
||||||
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||||
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
|
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
|
||||||
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
|
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
|
||||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
||||||
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
|
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
|
||||||
|
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
||||||
|
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
||||||
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
||||||
|
num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0)
|
||||||
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
preview_image_prompt = gr.Textbox(label='Preview prompt', value="")
|
preview_image_prompt = gr.Textbox(label='Preview prompt', value="")
|
||||||
@ -1056,6 +1120,8 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
inputs=[
|
inputs=[
|
||||||
process_src,
|
process_src,
|
||||||
process_dst,
|
process_dst,
|
||||||
|
process_width,
|
||||||
|
process_height,
|
||||||
process_flip,
|
process_flip,
|
||||||
process_split,
|
process_split,
|
||||||
process_caption,
|
process_caption,
|
||||||
@ -1074,7 +1140,10 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
learn_rate,
|
learn_rate,
|
||||||
dataset_directory,
|
dataset_directory,
|
||||||
log_directory,
|
log_directory,
|
||||||
|
training_width,
|
||||||
|
training_height,
|
||||||
steps,
|
steps,
|
||||||
|
num_repeats,
|
||||||
create_image_every,
|
create_image_every,
|
||||||
save_embedding_every,
|
save_embedding_every,
|
||||||
template_file,
|
template_file,
|
||||||
@ -1138,6 +1207,15 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
component_dict = {}
|
component_dict = {}
|
||||||
|
|
||||||
def open_folder(f):
|
def open_folder(f):
|
||||||
|
if not os.path.isdir(f):
|
||||||
|
print(f"""
|
||||||
|
WARNING
|
||||||
|
An open_folder request was made with an argument that is not a folder.
|
||||||
|
This could be an error or a malicious attempt to run code on your computer.
|
||||||
|
Requested path was: {f}
|
||||||
|
""", file=sys.stderr)
|
||||||
|
return
|
||||||
|
|
||||||
if not shared.cmd_opts.hide_ui_dir_config:
|
if not shared.cmd_opts.hide_ui_dir_config:
|
||||||
path = os.path.normpath(f)
|
path = os.path.normpath(f)
|
||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
@ -1151,10 +1229,13 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
changed = 0
|
changed = 0
|
||||||
|
|
||||||
for key, value, comp in zip(opts.data_labels.keys(), args, components):
|
for key, value, comp in zip(opts.data_labels.keys(), args, components):
|
||||||
if not opts.same_type(value, opts.data_labels[key].default):
|
if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default):
|
||||||
return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
|
return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson()
|
||||||
|
|
||||||
for key, value, comp in zip(opts.data_labels.keys(), args, components):
|
for key, value, comp in zip(opts.data_labels.keys(), args, components):
|
||||||
|
if comp == dummy_component:
|
||||||
|
continue
|
||||||
|
|
||||||
comp_args = opts.data_labels[key].component_args
|
comp_args = opts.data_labels[key].component_args
|
||||||
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
|
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
|
||||||
continue
|
continue
|
||||||
@ -1172,6 +1253,21 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
|
|
||||||
return f'{changed} settings changed.', opts.dumpjson()
|
return f'{changed} settings changed.', opts.dumpjson()
|
||||||
|
|
||||||
|
def run_settings_single(value, key):
|
||||||
|
if not opts.same_type(value, opts.data_labels[key].default):
|
||||||
|
return gr.update(visible=True), opts.dumpjson()
|
||||||
|
|
||||||
|
oldval = opts.data.get(key, None)
|
||||||
|
opts.data[key] = value
|
||||||
|
|
||||||
|
if oldval != value:
|
||||||
|
if opts.data_labels[key].onchange is not None:
|
||||||
|
opts.data_labels[key].onchange()
|
||||||
|
|
||||||
|
opts.save(shared.config_filename)
|
||||||
|
|
||||||
|
return gr.update(value=value), opts.dumpjson()
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as settings_interface:
|
with gr.Blocks(analytics_enabled=False) as settings_interface:
|
||||||
settings_submit = gr.Button(value="Apply settings", variant='primary')
|
settings_submit = gr.Button(value="Apply settings", variant='primary')
|
||||||
result = gr.HTML()
|
result = gr.HTML()
|
||||||
@ -1179,6 +1275,8 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
settings_cols = 3
|
settings_cols = 3
|
||||||
items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols)
|
items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols)
|
||||||
|
|
||||||
|
quicksettings_list = []
|
||||||
|
|
||||||
cols_displayed = 0
|
cols_displayed = 0
|
||||||
items_displayed = 0
|
items_displayed = 0
|
||||||
previous_section = None
|
previous_section = None
|
||||||
@ -1201,10 +1299,14 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
|
|
||||||
gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))
|
gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))
|
||||||
|
|
||||||
component = create_setting_component(k)
|
if item.show_on_main_page:
|
||||||
component_dict[k] = component
|
quicksettings_list.append((i, k, item))
|
||||||
components.append(component)
|
components.append(dummy_component)
|
||||||
items_displayed += 1
|
else:
|
||||||
|
component = create_setting_component(k)
|
||||||
|
component_dict[k] = component
|
||||||
|
components.append(component)
|
||||||
|
items_displayed += 1
|
||||||
|
|
||||||
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
|
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
|
||||||
request_notifications.click(
|
request_notifications.click(
|
||||||
@ -1218,7 +1320,6 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
|
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
|
||||||
restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
|
restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
|
||||||
|
|
||||||
|
|
||||||
def reload_scripts():
|
def reload_scripts():
|
||||||
modules.scripts.reload_script_body_only()
|
modules.scripts.reload_script_body_only()
|
||||||
|
|
||||||
@ -1265,12 +1366,16 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
css += css_hide_progressbar
|
css += css_hide_progressbar
|
||||||
|
|
||||||
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
|
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
|
||||||
|
with gr.Row(elem_id="quicksettings"):
|
||||||
|
for i, k, item in quicksettings_list:
|
||||||
|
component = create_setting_component(k)
|
||||||
|
component_dict[k] = component
|
||||||
|
|
||||||
settings_interface.gradio_ref = demo
|
settings_interface.gradio_ref = demo
|
||||||
|
|
||||||
with gr.Tabs() as tabs:
|
with gr.Tabs() as tabs:
|
||||||
for interface, label, ifid in interfaces:
|
for interface, label, ifid in interfaces:
|
||||||
with gr.TabItem(label, id=ifid):
|
with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
|
||||||
interface.render()
|
interface.render()
|
||||||
|
|
||||||
if os.path.exists(os.path.join(script_path, "notification.mp3")):
|
if os.path.exists(os.path.join(script_path, "notification.mp3")):
|
||||||
@ -1283,6 +1388,15 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
outputs=[result, text_settings],
|
outputs=[result, text_settings],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for i, k, item in quicksettings_list:
|
||||||
|
component = component_dict[k]
|
||||||
|
|
||||||
|
component.change(
|
||||||
|
fn=lambda value, k=k: run_settings_single(value, key=k),
|
||||||
|
inputs=[component],
|
||||||
|
outputs=[component, text_settings],
|
||||||
|
)
|
||||||
|
|
||||||
def modelmerger(*args):
|
def modelmerger(*args):
|
||||||
try:
|
try:
|
||||||
results = modules.extras.run_modelmerger(*args)
|
results = modules.extras.run_modelmerger(*args)
|
||||||
|
@ -36,10 +36,11 @@ class Upscaler:
|
|||||||
self.half = not modules.shared.cmd_opts.no_half
|
self.half = not modules.shared.cmd_opts.no_half
|
||||||
self.pre_pad = 0
|
self.pre_pad = 0
|
||||||
self.mod_scale = None
|
self.mod_scale = None
|
||||||
if self.name is not None and create_dirs:
|
|
||||||
|
if self.model_path is None and self.name:
|
||||||
self.model_path = os.path.join(models_path, self.name)
|
self.model_path = os.path.join(models_path, self.name)
|
||||||
if not os.path.exists(self.model_path):
|
if self.model_path and create_dirs:
|
||||||
os.makedirs(self.model_path)
|
os.makedirs(self.model_path, exist_ok=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import cv2
|
import cv2
|
||||||
|
23
script.js
23
script.js
@ -6,6 +6,10 @@ function get_uiCurrentTab() {
|
|||||||
return gradioApp().querySelector('.tabs button:not(.border-transparent)')
|
return gradioApp().querySelector('.tabs button:not(.border-transparent)')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function get_uiCurrentTabContent() {
|
||||||
|
return gradioApp().querySelector('.tabitem[id^=tab_]:not([style*="display: none"])')
|
||||||
|
}
|
||||||
|
|
||||||
uiUpdateCallbacks = []
|
uiUpdateCallbacks = []
|
||||||
uiTabChangeCallbacks = []
|
uiTabChangeCallbacks = []
|
||||||
let uiCurrentTab = null
|
let uiCurrentTab = null
|
||||||
@ -40,6 +44,25 @@ document.addEventListener("DOMContentLoaded", function() {
|
|||||||
mutationObserver.observe( gradioApp(), { childList:true, subtree:true })
|
mutationObserver.observe( gradioApp(), { childList:true, subtree:true })
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add a ctrl+enter as a shortcut to start a generation
|
||||||
|
*/
|
||||||
|
document.addEventListener('keydown', function(e) {
|
||||||
|
var handled = false;
|
||||||
|
if (e.key !== undefined) {
|
||||||
|
if((e.key == "Enter" && (e.metaKey || e.ctrlKey))) handled = true;
|
||||||
|
} else if (e.keyCode !== undefined) {
|
||||||
|
if((e.keyCode == 13 && (e.metaKey || e.ctrlKey))) handled = true;
|
||||||
|
}
|
||||||
|
if (handled) {
|
||||||
|
button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
|
||||||
|
if (button) {
|
||||||
|
button.click();
|
||||||
|
}
|
||||||
|
e.preventDefault();
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* checks that a UI element is not in another hidden element or tab content
|
* checks that a UI element is not in another hidden element or tab content
|
||||||
*/
|
*/
|
||||||
|
@ -10,7 +10,6 @@ from modules.processing import Processed, process_images
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
|
|
||||||
|
|
||||||
class Script(scripts.Script):
|
class Script(scripts.Script):
|
||||||
def title(self):
|
def title(self):
|
||||||
return "Prompts from file or textbox"
|
return "Prompts from file or textbox"
|
||||||
@ -29,6 +28,9 @@ class Script(scripts.Script):
|
|||||||
checkbox_txt.change(fn=lambda x: [gr.File.update(visible = not x), gr.TextArea.update(visible = x)], inputs=[checkbox_txt], outputs=[file, prompt_txt])
|
checkbox_txt.change(fn=lambda x: [gr.File.update(visible = not x), gr.TextArea.update(visible = x)], inputs=[checkbox_txt], outputs=[file, prompt_txt])
|
||||||
return [checkbox_txt, file, prompt_txt]
|
return [checkbox_txt, file, prompt_txt]
|
||||||
|
|
||||||
|
def on_show(self, checkbox_txt, file, prompt_txt):
|
||||||
|
return [ gr.Checkbox.update(visible = True), gr.File.update(visible = not checkbox_txt), gr.TextArea.update(visible = checkbox_txt) ]
|
||||||
|
|
||||||
def run(self, p, checkbox_txt, data: bytes, prompt_txt: str):
|
def run(self, p, checkbox_txt, data: bytes, prompt_txt: str):
|
||||||
if (checkbox_txt):
|
if (checkbox_txt):
|
||||||
lines = [x.strip() for x in prompt_txt.splitlines()]
|
lines = [x.strip() for x in prompt_txt.splitlines()]
|
||||||
|
@ -10,8 +10,8 @@ import numpy as np
|
|||||||
import modules.scripts as scripts
|
import modules.scripts as scripts
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import images
|
from modules import images, hypernetwork
|
||||||
from modules.processing import process_images, Processed
|
from modules.processing import process_images, Processed, get_correct_sampler
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import modules.sd_samplers
|
import modules.sd_samplers
|
||||||
@ -56,15 +56,17 @@ def apply_order(p, x, xs):
|
|||||||
p.prompt = prompt_tmp + p.prompt
|
p.prompt = prompt_tmp + p.prompt
|
||||||
|
|
||||||
|
|
||||||
samplers_dict = {}
|
def build_samplers_dict(p):
|
||||||
for i, sampler in enumerate(modules.sd_samplers.samplers):
|
samplers_dict = {}
|
||||||
samplers_dict[sampler.name.lower()] = i
|
for i, sampler in enumerate(get_correct_sampler(p)):
|
||||||
for alias in sampler.aliases:
|
samplers_dict[sampler.name.lower()] = i
|
||||||
samplers_dict[alias.lower()] = i
|
for alias in sampler.aliases:
|
||||||
|
samplers_dict[alias.lower()] = i
|
||||||
|
return samplers_dict
|
||||||
|
|
||||||
|
|
||||||
def apply_sampler(p, x, xs):
|
def apply_sampler(p, x, xs):
|
||||||
sampler_index = samplers_dict.get(x.lower(), None)
|
sampler_index = build_samplers_dict(p).get(x.lower(), None)
|
||||||
if sampler_index is None:
|
if sampler_index is None:
|
||||||
raise RuntimeError(f"Unknown sampler: {x}")
|
raise RuntimeError(f"Unknown sampler: {x}")
|
||||||
|
|
||||||
@ -78,7 +80,11 @@ def apply_checkpoint(p, x, xs):
|
|||||||
|
|
||||||
|
|
||||||
def apply_hypernetwork(p, x, xs):
|
def apply_hypernetwork(p, x, xs):
|
||||||
shared.hypernetwork = shared.hypernetworks.get(x, None)
|
hypernetwork.load_hypernetwork(x)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_clip_skip(p, x, xs):
|
||||||
|
opts.data["CLIP_stop_at_last_layers"] = x
|
||||||
|
|
||||||
|
|
||||||
def format_value_add_label(p, opt, x):
|
def format_value_add_label(p, opt, x):
|
||||||
@ -132,6 +138,7 @@ axis_options = [
|
|||||||
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
|
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
|
||||||
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
|
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
|
||||||
AxisOption("Eta", float, apply_field("eta"), format_value_add_label),
|
AxisOption("Eta", float, apply_field("eta"), format_value_add_label),
|
||||||
|
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label),
|
||||||
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
|
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -142,7 +149,7 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend):
|
|||||||
ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
|
ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
|
||||||
hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
|
hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
|
||||||
|
|
||||||
first_pocessed = None
|
first_processed = None
|
||||||
|
|
||||||
state.job_count = len(xs) * len(ys) * p.n_iter
|
state.job_count = len(xs) * len(ys) * p.n_iter
|
||||||
|
|
||||||
@ -151,8 +158,8 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend):
|
|||||||
state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
|
state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
|
||||||
|
|
||||||
processed = cell(x, y)
|
processed = cell(x, y)
|
||||||
if first_pocessed is None:
|
if first_processed is None:
|
||||||
first_pocessed = processed
|
first_processed = processed
|
||||||
|
|
||||||
try:
|
try:
|
||||||
res.append(processed.images[0])
|
res.append(processed.images[0])
|
||||||
@ -163,9 +170,9 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend):
|
|||||||
if draw_legend:
|
if draw_legend:
|
||||||
grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
|
grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
|
||||||
|
|
||||||
first_pocessed.images = [grid]
|
first_processed.images = [grid]
|
||||||
|
|
||||||
return first_pocessed
|
return first_processed
|
||||||
|
|
||||||
|
|
||||||
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
|
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
|
||||||
@ -195,10 +202,14 @@ class Script(scripts.Script):
|
|||||||
return [x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds]
|
return [x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds]
|
||||||
|
|
||||||
def run(self, p, x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds):
|
def run(self, p, x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds):
|
||||||
modules.processing.fix_seed(p)
|
if not no_fixed_seeds:
|
||||||
p.batch_size = 1
|
modules.processing.fix_seed(p)
|
||||||
|
|
||||||
initial_hn = shared.hypernetwork
|
if not opts.return_grid:
|
||||||
|
p.batch_size = 1
|
||||||
|
|
||||||
|
|
||||||
|
CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
|
||||||
|
|
||||||
def process_axis(opt, vals):
|
def process_axis(opt, vals):
|
||||||
if opt.label == 'Nothing':
|
if opt.label == 'Nothing':
|
||||||
@ -213,7 +224,6 @@ class Script(scripts.Script):
|
|||||||
m = re_range.fullmatch(val)
|
m = re_range.fullmatch(val)
|
||||||
mc = re_range_count.fullmatch(val)
|
mc = re_range_count.fullmatch(val)
|
||||||
if m is not None:
|
if m is not None:
|
||||||
|
|
||||||
start = int(m.group(1))
|
start = int(m.group(1))
|
||||||
end = int(m.group(2))+1
|
end = int(m.group(2))+1
|
||||||
step = int(m.group(3)) if m.group(3) is not None else 1
|
step = int(m.group(3)) if m.group(3) is not None else 1
|
||||||
@ -256,6 +266,17 @@ class Script(scripts.Script):
|
|||||||
|
|
||||||
valslist = [opt.type(x) for x in valslist]
|
valslist = [opt.type(x) for x in valslist]
|
||||||
|
|
||||||
|
# Confirm options are valid before starting
|
||||||
|
if opt.label == "Sampler":
|
||||||
|
samplers_dict = build_samplers_dict(p)
|
||||||
|
for sampler_val in valslist:
|
||||||
|
if sampler_val.lower() not in samplers_dict.keys():
|
||||||
|
raise RuntimeError(f"Unknown sampler: {sampler_val}")
|
||||||
|
elif opt.label == "Checkpoint name":
|
||||||
|
for ckpt_val in valslist:
|
||||||
|
if modules.sd_models.get_closet_checkpoint_match(ckpt_val) is None:
|
||||||
|
raise RuntimeError(f"Checkpoint for {ckpt_val} not found")
|
||||||
|
|
||||||
return valslist
|
return valslist
|
||||||
|
|
||||||
x_opt = axis_options[x_type]
|
x_opt = axis_options[x_type]
|
||||||
@ -307,6 +328,8 @@ class Script(scripts.Script):
|
|||||||
# restore checkpoint in case it was changed by axes
|
# restore checkpoint in case it was changed by axes
|
||||||
modules.sd_models.reload_model_weights(shared.sd_model)
|
modules.sd_models.reload_model_weights(shared.sd_model)
|
||||||
|
|
||||||
shared.hypernetwork = initial_hn
|
hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
|
||||||
|
|
||||||
|
opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers
|
||||||
|
|
||||||
return processed
|
return processed
|
||||||
|
69
style.css
69
style.css
@ -1,3 +1,7 @@
|
|||||||
|
.container {
|
||||||
|
max-width: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
.output-html p {margin: 0 0.5em;}
|
.output-html p {margin: 0 0.5em;}
|
||||||
|
|
||||||
.row > *,
|
.row > *,
|
||||||
@ -103,7 +107,12 @@
|
|||||||
|
|
||||||
#style_apply, #style_create, #interrogate{
|
#style_apply, #style_create, #interrogate{
|
||||||
margin: 0.75em 0.25em 0.25em 0.25em;
|
margin: 0.75em 0.25em 0.25em 0.25em;
|
||||||
min-width: 3em;
|
min-width: 5em;
|
||||||
|
}
|
||||||
|
|
||||||
|
#style_apply, #style_create, #deepbooru{
|
||||||
|
margin: 0.75em 0.25em 0.25em 0.25em;
|
||||||
|
min-width: 5em;
|
||||||
}
|
}
|
||||||
|
|
||||||
#style_pos_col, #style_neg_col{
|
#style_pos_col, #style_neg_col{
|
||||||
@ -393,10 +402,20 @@ input[type="range"]{
|
|||||||
|
|
||||||
#txt2img_interrupt, #img2img_interrupt{
|
#txt2img_interrupt, #img2img_interrupt{
|
||||||
position: absolute;
|
position: absolute;
|
||||||
width: 100%;
|
width: 50%;
|
||||||
height: 72px;
|
height: 72px;
|
||||||
background: #b4c0cc;
|
background: #b4c0cc;
|
||||||
border-radius: 8px;
|
border-radius: 0px;
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
#txt2img_skip, #img2img_skip{
|
||||||
|
position: absolute;
|
||||||
|
width: 50%;
|
||||||
|
right: 0px;
|
||||||
|
height: 72px;
|
||||||
|
background: #b4c0cc;
|
||||||
|
border-radius: 0px;
|
||||||
display: none;
|
display: none;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -411,3 +430,47 @@ input[type="range"]{
|
|||||||
#img2img_image div.h-60{
|
#img2img_image div.h-60{
|
||||||
height: 480px;
|
height: 480px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#context-menu{
|
||||||
|
z-index:9999;
|
||||||
|
position:absolute;
|
||||||
|
display:block;
|
||||||
|
padding:0px 0;
|
||||||
|
border:2px solid #a55000;
|
||||||
|
border-radius:8px;
|
||||||
|
box-shadow:1px 1px 2px #CE6400;
|
||||||
|
width: 200px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.context-menu-items{
|
||||||
|
list-style: none;
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.context-menu-items a{
|
||||||
|
display:block;
|
||||||
|
padding:5px;
|
||||||
|
cursor:pointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
.context-menu-items a:hover{
|
||||||
|
background: #a55000;
|
||||||
|
}
|
||||||
|
|
||||||
|
#quicksettings > div{
|
||||||
|
border: none;
|
||||||
|
background: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
#quicksettings > div > div{
|
||||||
|
max-width: 32em;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
canvas[key="mask"] {
|
||||||
|
z-index: 12 !important;
|
||||||
|
filter: invert();
|
||||||
|
mix-blend-mode: multiply;
|
||||||
|
pointer-events: none;
|
||||||
|
}
|
||||||
|
Binary file not shown.
Before Width: | Height: | Size: 526 KiB After Width: | Height: | Size: 329 KiB |
10
webui.py
10
webui.py
@ -5,6 +5,8 @@ import importlib
|
|||||||
import signal
|
import signal
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
|
from fastapi.middleware.gzip import GZipMiddleware
|
||||||
|
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
|
|
||||||
from modules import devices, sd_samplers
|
from modules import devices, sd_samplers
|
||||||
@ -58,6 +60,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
|
|||||||
shared.state.current_latent = None
|
shared.state.current_latent = None
|
||||||
shared.state.current_image = None
|
shared.state.current_image = None
|
||||||
shared.state.current_image_sampling_step = 0
|
shared.state.current_image_sampling_step = 0
|
||||||
|
shared.state.skipped = False
|
||||||
shared.state.interrupted = False
|
shared.state.interrupted = False
|
||||||
shared.state.textinfo = None
|
shared.state.textinfo = None
|
||||||
|
|
||||||
@ -88,6 +91,9 @@ modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
|
|||||||
shared.sd_model = modules.sd_models.load_model()
|
shared.sd_model = modules.sd_models.load_model()
|
||||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
|
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
|
||||||
|
|
||||||
|
loaded_hypernetwork = modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)
|
||||||
|
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
|
||||||
|
|
||||||
|
|
||||||
def webui():
|
def webui():
|
||||||
# make the program just exit at ctrl+c without waiting for anything
|
# make the program just exit at ctrl+c without waiting for anything
|
||||||
@ -101,7 +107,7 @@ def webui():
|
|||||||
|
|
||||||
demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
|
demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
|
||||||
|
|
||||||
demo.launch(
|
app,local_url,share_url = demo.launch(
|
||||||
share=cmd_opts.share,
|
share=cmd_opts.share,
|
||||||
server_name="0.0.0.0" if cmd_opts.listen else None,
|
server_name="0.0.0.0" if cmd_opts.listen else None,
|
||||||
server_port=cmd_opts.port,
|
server_port=cmd_opts.port,
|
||||||
@ -111,6 +117,8 @@ def webui():
|
|||||||
prevent_thread_lock=True
|
prevent_thread_lock=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
app.add_middleware(GZipMiddleware,minimum_size=1000)
|
||||||
|
|
||||||
while 1:
|
while 1:
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
if getattr(demo, 'do_restart', False):
|
if getattr(demo, 'do_restart', False):
|
||||||
|
Loading…
Reference in New Issue
Block a user