#!/usr/bin/env python# coding: utf-8'''Directly denoising Diffusion model, the Simplest deep learning.''''''Version 1.0''''''python diff.py'''__credits__=["iBoxDB","Bruce Yang CL-N","2025-5"]# 3rd parts#https://pytorch.org CPU version onlyimporttorchimporttorch.nnasnnimporttorch.utils.dataastdatafromtorch.optimimportAdam,AdamWfromtorchvision.datasetsimportMNIST,CIFAR10fromtorchvision.utilsimportsave_image,make_gridimporttorchvision.transformsastransformsimporttorchvision.transforms.functionalasTFimportmatplotlib.pyplotaspltimportnumpyasnpimportmathimportosth=torchDataLoader=tdata.DataLoader'''if over trained, remove the *.pt files.set 'epochs' to 20and re-train several times till best results.it using random noisy, results varying.'''epochs=60#0 train_batch_size=36th.set_default_dtype(th.float32)th.set_num_threads(4)dataset_path='~/datasets'#download to hereprint(os.path.expanduser(dataset_path))dataset='MNIST'img_size=(1,28,28)img_size=(1,16,16)transform=transforms.Compose([transforms.Resize((img_size[1],img_size[2]),transforms.InterpolationMode.BILINEAR),transforms.ToTensor(),])train_dataset=MNIST(dataset_path,transform=transform,train=True,download=True)#test_dataset = MNIST(dataset_path, transform=transform, train=False, download=True)print(len(train_dataset))generator1=torch.Generator().manual_seed(69981)train_dataset,_=tdata.random_split(train_dataset,[0.01,0.99],generator1)'''train_dataset = [x for x in train_dataset if x[1] == 9 ]train_dataset = 
train_dataset[0:train_batch_size*2]'''train_dataset=[xforxintrain_dataset]print(len(train_dataset))defdraw_sample_image(x,postfix,block=True):plt.close('all')plt.figure(figsize=(5,5))plt.axis("off")plt.title(postfix)im=make_grid(x.detach().cpu(),nrow=int(math.sqrt(len(x))),scale_each=True,normalize=False)im=TF.resize(im,(im.size(1),im.size(2)))im=np.transpose(im,(1,2,0))plt.imshow(im)plt.show(block=block)ifnotblock:plt.pause(3)plt.close('all')classDenoiser(nn.Module):def__init__(self):super(Denoiser,self).__init__()C,H,W=img_sizeself.unet=nn.Sequential(nn.Conv2d(C,32,4,2,1),nn.InstanceNorm2d(32),nn.Conv2d(32,64,4,2,1),nn.InstanceNorm2d(64),nn.Conv2d(64,128,1),nn.SiLU(),nn.Conv2d(128,128,2),nn.LeakyReLU(),nn.Conv2d(128,256,2),nn.LeakyReLU(),nn.ConvTranspose2d(256,128,2),nn.LeakyReLU(),nn.ConvTranspose2d(128,64,2),nn.LeakyReLU(),nn.ConvTranspose2d(64,32,4,2,1),nn.LeakyReLU(),nn.ConvTranspose2d(32,C,4,2,1),)defforward(self,x):returnself.bruce_forward(x)defbruce_forward(self,x):res=self.unet(x)returnx-resclassDiffusion(nn.Module):def__init__(self):super(Diffusion,self).__init__()self.model=Denoiser()defscale_to_minus_one_to_one(self,x):returnx*2-1defreverse_scale_to_zero_to_one(self,x):return(x+1)/2defbruce_noisy(self,x_zeros,ranLen=31):x_zeros=x_zeros.detach()x_zeros=self.scale_to_minus_one_to_one(x_zeros)rs=[]es=[]for_inrange(1):target=torch.rand_like(x_zeros)target=self.scale_to_minus_one_to_one(target)alpha=20/100epsilon=target-x_zeros*alphars.append(target)es.append(epsilon)for_inrange(ranLen-1):alpha=th.randint(21,100,(1,)).item()/100epsilon=torch.rand_like(x_zeros)epsilon=self.scale_to_minus_one_to_one(epsilon)epsilon=epsilon*(1-alpha)noisy_sample=x_zeros*alpha+epsilonrs.append(noisy_sample)es.append(epsilon)returnrs,es@th.no_grad()defsample(self,time=64):target=torch.rand(img_size)target=self.scale_to_minus_one_to_one(target)rs=[]target=target.unsqueeze(0)foralphainrange(20,20+time):epsilon=self.model(target).detach()ifalpha==20:passalpha=alpha/100x_zeros=(targe
t-epsilon)/(alpha)a=x_zeros.squeeze(0)a=a.clamp(-1,1)a=self.reverse_scale_to_zero_to_one(a)rs.append(a)alpha+=0.01epsilon=torch.rand_like(x_zeros)epsilon=self.scale_to_minus_one_to_one(epsilon)epsilon=epsilon*(1-alpha)epsilon=epsilon*0.5target=x_zeros*alpha+epsilontarget=target.clamp(-5,5)returnth.stack(rs)defforward(self,x,y):inputX=[]inputY=[]fora,_inzip(x,y):rs,es=self.bruce_noisy(a)inputX+=rsinputY+=esinputX=th.stack(inputX)inputY=th.stack(inputY)returnself.model(inputX),inputYdenoising_loss=nn.MSELoss()diffusion=Diffusion()lr=0.01optimizer=AdamW(diffusion.parameters(),lr=lr)ifos.path.exists("diff.pt"):a=th.load("diff.pt")try:diffusion.load_state_dict(a["d"])optimizer.load_state_dict(a["o"])print("load diff.pt")exceptExceptionase:print(e)train_loader=DataLoader(dataset=train_dataset,batch_size=train_batch_size,shuffle=True,)forx,yintrain_loader:x=x[0:36]print(th.min(x),th.max(x))#draw_sample_image(x,"Show")x=x[0]rs,es=diffusion.bruce_noisy(x,36)x=th.stack(rs)print(th.min(x),th.max(x))x=(x+1)/2#draw_sample_image(x,"Noisy")x=th.stack(es)print(th.min(x),th.max(x))x*=(1/1.21)x=(x+1)/2#draw_sample_image(x,"De Noisy")breakcount_loader=len(train_loader)defcount_parameters(model):returnsum(p.numel()forpinmodel.parameters()ifp.requires_grad)print("Model Parameters: ",count_parameters(diffusion),count_loader)defshow_samples(time=70,block=True):es=[]forlinrange(64):x=diffusion.sample(time)[-1]es.append(x)x=th.stack(es)draw_sample_image(x,"Samples",block)diffusion.train()forepochinrange(epochs):noise_prediction_loss=0train_loader=DataLoader(dataset=train_dataset,batch_size=train_batch_size,shuffle=True,)forbatch_idx,(x,y)inenumerate(train_loader):optimizer.zero_grad()x,y=diffusion(x,y)loss=denoising_loss(x.view(-1),y.view(-1))noise_prediction_loss+=loss.item()loss.backward()optimizer.step()print(f"{batch_idx} / {count_loader}.",loss.item())noise_prediction_loss=noise_prediction_loss/count_loaderprint("Epoch",epoch+1,f"/ {epochs} complete."," L: 
",noise_prediction_loss)a={"d":diffusion.state_dict(),"o":optimizer.state_dict()}th.save(a,"diff.pt")print("save diff.pt")ifepoch%10==1:show_samples(70,False)ifnoise_prediction_loss<0.005:print(epoch+1," Goto Eval, remove diff.pt before re-train ")breakdiffusion.eval()for_inrange(2):show_samples()forlinrange(10):x=diffusion.sample(81)print(th.min(x),th.max(x))draw_sample_image(x,"Sample Single")print("End.")
# Last edit: iBoxDB 2025-06-01
# Last edit: iBoxDB 2025-06-01
# ShortLink: Directly denoising Diffusion model, the Simplest deep learning.