iBoxDB - 2025-04-29
#!/usr/bin/env python
# coding: utf-8

'''Mutant GAN: a generative adversarial network -- the simplest deep learning example.'''
'''Version 1.0'''
'''python mutant_gan.py'''

__credits__ = ["iBoxDB", "Bruce Yang CL-N", "2025-4"]



# Third-party dependencies
# https://pytorch.org  (CPU version only)
import torch
# pip install matplotlib
import matplotlib.pyplot as plt
# pip install pillow
from PIL import Image, ImageDraw, ImageFont

import torch.utils.data as data
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import os 

# Short aliases used throughout the rest of the file.
th = torch
nn, optim = th.nn, th.optim

# Runtime configuration: float64 becomes the default floating-point dtype
# for newly created tensors, and PyTorch is limited to four CPU threads.
th.set_default_dtype(th.float64)
th.set_num_threads(4)

# Character vocabulary. Index 0 is the space character, i.e. an empty glyph.
char_list = [*" W1234567890QERTYUIOPASDFGHJKLZXCVBNM"]
max_seq = 2  # presumably the number of characters per label -- TODO confirm
font_size_resize = (32, 32)  # size every rendered glyph image is resized to
RGB = 1  # number of image channels used by the generator/discriminator

class MGenerator(nn.Module):
    """Generator that turns a pair of character indices into a glyph image.

    Sub-networks (their attribute names and layer order define the
    state-dict keys, so they must not be renamed or reordered):

    * ``true_input`` -- embeds every index into 16 values and reshapes an
      ``(N, 2)`` batch of label pairs into an ``(N, 32, 1, 1)`` latent.  The
      embedding table has ``2 * len(char_list)`` rows.
    * ``mu`` -- "mutation" branch, only used when ``mutant != 0``; maps an
      ``(N, 32, 1, 1)`` latent to a new ``(N, 32, 1, 1)`` latent.
    * ``main`` -- transposed-convolution stack that upsamples the latent to
      an ``(N, RGB, 32, 32)`` image.
    """

    def __init__(self):
        super().__init__()

        self.true_input = nn.Sequential(
            nn.Embedding(num_embeddings=2 * len(char_list), embedding_dim=16),
            # (N, 2, 16) -> (N, 2, 4, 4)
            nn.Unflatten(dim=-1, unflattened_size=(4, 4)),
            # (N, 2, 4, 4) -> (N, 32, 1): one 4x4 patch per channel
            nn.Unfold(kernel_size=(4, 4)),
            # (N, 32, 1) -> (N, 32, 1, 1)
            nn.Unflatten(dim=-1, unflattened_size=(1, 1)),
        )

        self.mu = nn.Sequential(
            # (N, 32, 1, 1) -> (N, 64, 4, 4)
            nn.ConvTranspose2d(32, 64, kernel_size=4, stride=1, padding=0, bias=False),
            nn.LazyBatchNorm2d(),
            # (N, 64, 4, 4) -> (N, 32, 1, 1)
            nn.Conv2d(64, 32, kernel_size=4, stride=1, padding=0, bias=False),
        )

        self.main = nn.Sequential(
            nn.LazyBatchNorm2d(),
            # (N, 32, 1, 1) -> (N, 128, 4, 4)
            nn.ConvTranspose2d(32, 128, kernel_size=4, stride=1, padding=0, bias=False),
            nn.LazyBatchNorm2d(),
            nn.ReLU(inplace=True),
            # -> (N, 64, 8, 8)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LazyBatchNorm2d(),
            nn.ReLU(inplace=True),
            # -> (N, 32, 16, 16)
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LazyBatchNorm2d(),
            nn.ReLU(inplace=True),
            # -> (N, RGB, 32, 32)
            nn.ConvTranspose2d(32, RGB, kernel_size=4, stride=2, padding=1, bias=False),
        )

    def forward(self, input, mutant=0):
        """Generate images for a batch of label pairs.

        Args:
            input: integer index tensor; the callers in this file pass
                shape ``(N, 2)``.
            mutant: ``0`` uses the plain path.  Any other value adds
                ``sin(linspace(0, 2*pi, 32) * mutant) + mutant`` to the
                32 latent channels and routes the result through ``mu``
                before the image stack.

        Returns:
            ``(image, s)`` where ``image`` is ``sigmoid(main(latent) / 10)``
            and ``s`` is the output of ``mu`` when ``mutant != 0``, otherwise
            a ``(1, 1, 1, 1)`` tensor of ones.
        """
        latent = self.true_input(input)

        # Placeholder returned in the second slot when no mutation is applied.
        mu_out = th.ones((1, 1, 1, 1))
        if mutant != 0:
            phase = th.linspace(0, th.pi * 2, 32) * mutant
            offset = (th.sin(phase) + mutant).view(1, 32, 1, 1)
            mu_out = self.mu(latent + offset)
            latent = mu_out

        image = th.sigmoid(self.main(latent) / 10)
        return image, mu_out

# Import-time smoke test: run a random batch of 5 label pairs (indices in
# [10, 20)) through a fresh generator and print the image batch shape.
# `b` and `mu` stay bound at module level and are reused by the
# discriminator smoke test further down.
a = th.randint(10,20,(5,2))
g = MGenerator()
b,mu = g(a)
print(b.shape)

class Discriminator(nn.Module):
    """Convolutional critic scoring ``(N, RGB, 32, 32)`` images in ``(0, 1)``.

    ``main`` downsamples the image to an ``(N, 32, 1, 1)`` feature and
    ``true_ouput`` reduces it to one value per sample.  The attribute names
    (including the ``true_ouput`` spelling) are state-dict keys and are kept
    as they are.
    """

    def __init__(self):
        super().__init__()

        self.main = nn.Sequential(
            # (N, RGB, 32, 32) -> (N, 32, 16, 16)
            nn.Conv2d(RGB, 32, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LazyBatchNorm2d(),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # -> (N, 64, 8, 8)
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LazyBatchNorm2d(),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # -> (N, 128, 4, 4)
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LazyBatchNorm2d(),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # -> (N, 32, 1, 1)
            nn.Conv2d(128, 32, kernel_size=4, stride=2, padding=0, bias=False),
        )

        self.true_ouput = nn.Sequential(
            nn.LazyBatchNorm2d(),
            # 1x1 convolution: (N, 32, 1, 1) -> (N, 1, 1, 1)
            nn.Conv2d(32, 1, kernel_size=1, bias=False),
            # -> (N, 1)
            nn.Flatten(),
        )

    def forward(self, input, mutant):
        """Return one sigmoid score per image, shape ``(N, 1)``.

        Args:
            input: image batch of shape ``(N, RGB, 32, 32)``.
            mutant: accepted for call compatibility; not used by this method.
        """
        features = self.main(input)
        logits = self.true_ouput(features)
        return th.sigmoid(logits)

# Import-time smoke test: score the module-level generator output `b`
# (with its companion `mu`) using a fresh discriminator and print the
# resulting shape.  Note that `b` is rebound to the score tensor here.
d = Discriminator()
b = d(b,mu)
print(b.shape)


class ImgDataset(data.Dataset):
    """In-memory dataset of rendered two-character images.

    For every ordered pair ``(i, j)`` of indices into ``char_list`` the text
    ``char_list[i] + ' ' + char_list[j]`` is drawn with Pillow's default
    bitmap font, resized to ``font_size_resize`` and stored as a
    ``(1, H, W)`` float64 tensor with values scaled to ``[0, 1]``.

    Attributes:
        lst: image tensors, ordered by ``i`` then ``j``.
        lst_t: int32 ``(i, j)`` label tensors, parallel to ``lst``.
        hash: tensor of shape
            ``(len(char_list), len(char_list), 1, *font_size_resize)`` giving
            direct ``hash[i, j]`` access to the image of a label pair.
        fig: the matplotlib figure created by the last :meth:`show` call.
    """

    def __init__(self):
        super().__init__()
        self.lst = []
        self.lst_t = []
        font = ImageFont.load_default_imagefont()
        glyph_width = font.getlength("W")
        print("Img Font Size:", glyph_width)
        # The canvas is 4 glyph widths wide and 2 high.  Image.new needs
        # integer sizes, so coerce in case the font reports a float length.
        side = int(glyph_width * 2)
        vocab = len(char_list)
        self.hash = th.empty((vocab, vocab, 1, *font_size_resize))
        for i, left in enumerate(char_list):
            for j, right in enumerate(char_list):
                img = Image.new("L", (side * 2, side), (0,))
                draw = ImageDraw.Draw(img)
                draw.fontmode = "1"  # draw the text without anti-aliasing
                draw.text((2, 0), left + ' ' + right, (255,), font=font)
                img = img.resize(font_size_resize, Image.Resampling.NEAREST)
                pixels = th.tensor(np.array(img), dtype=th.float64)
                pixels = pixels.unsqueeze(0) / 255  # (1, H, W), values in [0, 1]
                self.lst.append(pixels)
                self.lst_t.append(th.tensor((i, j), dtype=th.int32))
                self.hash[i, j] = pixels

    def __len__(self):
        """Return the number of rendered label pairs."""
        return len(self.lst)

    def __getitem__(self, index):
        """Return ``(label_pair, image)`` for the sample at ``index``."""
        return self.lst_t[index], self.lst[index]

    def gethash(self, labs):
        """Return a list with the reference image for every pair in ``labs``.

        Args:
            labs: iterable of index pairs ``(i, j)`` into ``char_list``.
        """
        return [self.hash[k[0], k[1]] for k in labs]

    def show(self, imgs=None, repeat=True, pause=5):
        """Animate image triples in three side-by-side panels.

        Args:
            imgs: iterable of frames; each frame is indexable and holds three
                ``(1, H, W)`` tensors shown under "Creating", "Learning" and
                "Image".  Defaults to one frame per dataset image, with that
                image shown in all three panels.
            repeat: forwarded to ``plt.show(block=...)``.
            pause: seconds passed to ``plt.pause`` before all figures are
                closed.
        """
        if imgs is None:
            # One frame per dataset image.  (Passing the three lists directly
            # would make each *list* a frame and only ever show images 0-2.)
            imgs = list(zip(self.lst, self.lst, self.lst))

        self.fig = plt.figure(1, figsize=(12, 4))
        axes = self.fig.subplots(1, 3)
        for ax, title in zip(axes, ("Creating", "Learning", "Image")):
            ax.set_title(title)
            # plt.axis("off") would only affect the current (last) axes,
            # so hide the ticks on every panel explicitly.
            ax.axis("off")

        ims = [
            [
                ax.imshow(frame[k].squeeze(0).numpy(), cmap="gray", animated=True)
                for k, ax in enumerate(axes)
            ]
            for frame in imgs
        ]
        # Keep a reference for the duration of the call so the animation is
        # not garbage-collected while the window is displayed.
        ani = animation.ArtistAnimation(self.fig, ims, interval=200, repeat_delay=200)

        plt.show(block=repeat)
        plt.pause(pause)
        plt.close('all')



# Build the full glyph dataset once and serve it in shuffled batches of 128.
dataloader = data.DataLoader(ImgDataset(), batch_size=128,shuffle=True)
print(len(dataloader.dataset))
# Uncomment to preview the dataset images before training.
#dataloader.dataset.show()

# Value passed as the generator's `mutant` argument whenever the mutation
# branch is used (in testGenerator and in the training loop).
g_mutant = 1
@th.no_grad()
def testGenerator(num,repeat):
    """Sample ``num`` random label pairs and display what the generator draws.

    For each pair three images are shown side by side: the "mutant" output
    (labels shifted by ``len(char_list)`` with ``g_mutant`` applied), the
    plain generator output, and the reference image from the dataset.

    Runs with gradient tracking disabled.  The decorator is written with
    parentheses because the bare ``@th.no_grad`` form is only accepted by
    recent PyTorch releases.

    Args:
        num: number of random label pairs to draw.
        repeat: forwarded to ``ImgDataset.show`` (controls whether the
            matplotlib window blocks).
    """
    labs = th.randint(0, len(char_list), (num, 2))
    created, _ = netG(labs + len(char_list), g_mutant)
    learned, _ = netG(labs)
    reference = dataloader.dataset.gethash(labs)
    # No detach needed: under no_grad the outputs carry no autograd graph.
    frames = list(zip(created, learned, reference))
    dataloader.dataset.show(frames, repeat)



# ---- Training setup ----------------------------------------------------
criterion = nn.BCELoss()  # adversarial loss on the discriminator's sigmoid score
gloss = nn.MSELoss()      # pixel-wise loss between generated and dataset images
real_label = 1.
fake_label = 0.
lr = 0.01
netG = MGenerator()
netD = Discriminator()

# Resume from checkpoints in the working directory when they exist.
# NOTE(review): th.load unpickles the file unless weights-only loading is in
# effect -- only resume from checkpoints you trust.
if os.path.exists("netG.pt"):
    netG.load_state_dict(th.load("netG.pt"))
    print("load ","netG.pt")
if os.path.exists("netD.pt"):
    netD.load_state_dict(th.load("netD.pt"))
    print("load ","netD.pt")

optimizerD = optim.AdamW(netD.parameters(), lr=lr)
optimizerG = optim.AdamW(netG.parameters(), lr=lr)
netG.train()
netD.train()

# ---- Training loop -----------------------------------------------------
train_time_total = 50
for train_time in range(train_time_total):
    print(f"{train_time} / {train_time_total} Mutant-GAN")
    for labs, imgs in dataloader:

        # (1) Generator, supervised: reproduce the dataset image of each
        #     known label pair.
        netG.zero_grad()
        fake, mu = netG(labs)
        errG = gloss(fake.view(-1), imgs.view(-1))
        errG.backward()
        optimizerG.step()

        # (2) Discriminator, "real" side: the generator's output for known
        #     labels is what gets labelled real.  The input is detached
        #     because only netD is updated here; gradients flowing into netG
        #     would be discarded by the netG.zero_grad() in step (4) anyway.
        netD.zero_grad()
        fake, mu = netG(labs)
        label = th.full((len(labs),), real_label)
        output = netD(fake.detach(), mu.detach()).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        optimizerD.step()

        # Labels shifted past the end of char_list select embedding rows that
        # have no dataset image; together with g_mutant this is where the
        # generator has to create new characters.
        mutant_labs = labs + len(char_list)
        fake, mu = netG(mutant_labs, g_mutant)

        # (3) Discriminator, "fake" side: mutant images are labelled fake.
        netD.zero_grad()
        output = netD(fake.detach(), mu.detach()).view(-1)
        label.fill_(fake_label)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        optimizerD.step()

        # (4) Generator, adversarial: push the mutant images towards being
        #     scored as real by the just-updated discriminator.
        netG.zero_grad()
        output = netD(fake, mu).view(-1)
        label.fill_(real_label)
        errG = criterion(output, label)
        errG.backward()
        optimizerG.step()

        print(errD_real.item(), errD_fake.item(), errG.item())

    # Checkpoint after every epoch.
    th.save(netG.state_dict(), "netG.pt")
    th.save(netD.state_dict(), "netD.pt")
    print("save ",train_time)

    # Preview 30 samples (non-blocking window) after every second epoch.
    if train_time % 2 == 1:
        testGenerator(30, False)

print("ShowUI")
testGenerator(25*25, True)
print("End.")