In this post we will train a GAN to model age progression in faces, i.e. generate an aged photo of a person given a photo at a younger age (and also the reverse). We will use a CycleGAN for this.


Old to young sample 3


CycleGANs were introduced in the paper Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks, where the authors present an approach for learning to translate an image from a source domain X to a target domain Y in the absence of paired examples.
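
For reference, the paper's full objective combines an adversarial loss for each mapping with a cycle-consistency term, where λ weights the cycle term (the paper uses λ = 10):

$$\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{GAN}(G, D_Y, X, Y) + \mathcal{L}_{GAN}(F, D_X, Y, X) + \lambda\,\mathcal{L}_{cyc}(G, F)$$

$$\mathcal{L}_{cyc}(G, F) = \mathbb{E}_{x}\big[\lVert F(G(x)) - x\rVert_1\big] + \mathbb{E}_{y}\big[\lVert G(F(y)) - y\rVert_1\big]$$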

For face images across a range of ages we will use the UTKFace dataset, which provides a cropped set of face images labelled with age, gender, race, etc.

We will be using the following two good references that use CycleGAN in order to build and train our models:

  1. https://github.com/sungnam0/Face-Aging-with-CycleGAN
  2. https://machinelearningmastery.com/cyclegan-tutorial-with-keras/

Import necessary modules

import os
import numpy as np
import pandas as pd
import keras
import tensorflow as tf
from matplotlib import pyplot

from keras.layers import Input, Conv2D, Conv2DTranspose, Activation, LeakyReLU, Lambda, Add
from keras.models import Model, load_model
from keras.optimizers import Adam
from keras_preprocessing import image

from numpy import load, zeros, ones, asarray, savez_compressed
from numpy.random import randint
from random import random

Set the TensorFlow backend config to allocate GPU memory as needed instead of pre-allocating all of it

config = tf.ConfigProto()
config.gpu_options.allow_growth = True

# Create a session with the above options specified.
keras.backend.tensorflow_backend.set_session(tf.Session(config=config))

Mount Google Drive to retrieve and store the files used in this project

from google.colab import drive
drive.mount('/content/drive', force_remount=True)
Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive
!ls -al '/content/drive/My Drive/FaceGAN/'
total 105354
drwx------ 2 root root      4096 Aug 22 10:59 results
-rw------- 1 root root   1239467 Aug 22 09:56 utk_data.csv
drwx------ 2 root root      4096 Aug 21 08:41 UTKFace
-rw------- 1 root root 106634631 Aug 21 07:15 UTKFace.tar.gz

Get the UTKFace dataset

!tar zxf '/content/drive/My Drive/FaceGAN/UTKFace.tar.gz' UTKFace

Parse the data

The labels of each face image are embedded in its file name, formatted as [age]_[gender]_[race]_[date&time].jpg:

[age] is an integer from 0 to 116, indicating the age
[gender] is either 0 (male) or 1 (female)
[race] is an integer from 0 to 4, denoting White, Black, Asian, Indian, and Others (like Hispanic, Latino, Middle Eastern)
[date&time] is in the format yyyymmddHHMMSSFFF, showing the date and time the image was collected for UTKFace

For example, 25_1_3_20170119172104288.jpg.chip.jpg is a 25-year-old Indian female.

data=[]

for filename in os.listdir('./UTKFace'):
  parts=filename.split('_')
  #print(parts[3])
  item={}
  item['image']=filename
  item['age']=parts[0]
  item['gender']=parts[1]
  item['race']=parts[2]
  if (len(parts)==4):
    item['date_time']=parts[3]
  data.append(item)
  
utk_data=pd.DataFrame(data)  
utk_data.describe()  
          age                       date_time  gender                                  image   race
count   23708                           23705   23708                                  23708  23708
unique    104                           23479       2                                  23708      8
top        26  20170110173815028.jpg.chip.jpg       0  26_1_0_20170112213001988.jpg.chip.jpg      0
freq     2197                               7   12391                                      1  10078
utk_data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 23708 entries, 0 to 23707
Data columns (total 5 columns):
age          23708 non-null object
date_time    23705 non-null object
gender       23708 non-null object
image        23708 non-null object
race         23708 non-null object
dtypes: object(5)
memory usage: 926.2+ KB
utk_data.head()
  age                       date_time gender                                  image race
0  25  20170119172104288.jpg.chip.jpg      1  25_1_3_20170119172104288.jpg.chip.jpg    3
1  25  20170117141726361.jpg.chip.jpg      1  25_1_0_20170117141726361.jpg.chip.jpg    0
2  27  20170116001407357.jpg.chip.jpg      0  27_0_1_20170116001407357.jpg.chip.jpg    1
3  10  20170103200501766.jpg.chip.jpg      0  10_0_4_20170103200501766.jpg.chip.jpg    4
4  26  20170116184024662.jpg.chip.jpg      1  26_1_0_20170116184024662.jpg.chip.jpg    0

We do not need the date/time column, so delete it

del utk_data['date_time']

Define age categories and apply them to the dataset

def age_cat_fn(age):
  # map a raw age to one of six categories:
  # 0: <18, 1: 18-25, 2: 26-39, 3: 40-49, 4: 50-60, 5: >60
  age = int(age)
  if age < 18:
    return 0
  elif age <= 25:
    return 1
  elif age <= 39:
    return 2
  elif age <= 49:
    return 3
  elif age <= 60:
    return 4
  else:
    return 5

utk_data['age_cat']=utk_data.age.map(age_cat_fn)
utk_data.to_csv('utk_data.csv',sep=',')
!cp 'utk_data.csv' '/content/drive/My Drive/EIP3/session7'
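
As a quick sanity check (this cell is illustrative and was not part of the original run), we can look at how the images are spread across the new categories:

# number of images per age category (0: <18, 1: 18-25, 2: 26-39, 3: 40-49, 4: 50-60, 5: >60)
print(utk_data['age_cat'].value_counts().sort_index())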

The data with the age category added

utk_data.head()
  age gender                                  image race  age_cat
0  25      1  25_1_3_20170119172104288.jpg.chip.jpg    3        1
1  25      1  25_1_0_20170117141726361.jpg.chip.jpg    0        1
2  27      0  27_0_1_20170116001407357.jpg.chip.jpg    1        2
3  10      0  10_0_4_20170103200501766.jpg.chip.jpg    4        0
4  26      1  26_1_0_20170116184024662.jpg.chip.jpg    0        2

Split the data into two domains:

  1. Young: age category 1 (18-25)
  2. Old: age category 4 (50-60)

data_A=utk_data[utk_data.age_cat==1]
data_B=utk_data[utk_data.age_cat==4]
print(data_A['age_cat'].count())
print(data_B['age_cat'].count())
3404
2592

Get the images belonging to the two domains and save them as a compressed numpy array, so that we can load them when necessary instead of processing the UTKFace dataset multiple times

image_dir='./UTKFace/'


image_paths_A = data_A['image'].tolist()
image_paths_B = data_B['image'].tolist()

# load up to 2000 images per domain into a single (N, 128, 128, 3) tensor
def load_images(image_paths, label, n_max=2000):
  images = None
  for i, image_path in enumerate(image_paths[:n_max]):
    if i % 1000 == 0:
      print("processing set " + label + " image num " + str(i))
    try:
      # load the image resized to 128x128 and convert the PIL image to a numpy ndarray
      loaded_image = image.img_to_array(
          image.load_img(image_dir + image_path, target_size=(128, 128, 3)))
      # add a batch dimension
      loaded_image = np.expand_dims(loaded_image, axis=0)
      # concatenate all images into one tensor
      images = loaded_image if images is None else np.concatenate([images, loaded_image], axis=0)
    except Exception as e:
      print("Error:", i, e)
  return images

images_A = load_images(image_paths_A, 'A')
images_B = load_images(image_paths_B, 'B')

processing set A image num 0
processing set A image num 1000
processing set B image num 0
processing set B image num 1000
print('images_A :') 
print(images_A.shape)  
print('images_B :')
print(images_B.shape)
images_A :
(2000, 128, 128, 3)
images_B :
(2000, 128, 128, 3)
filename = '/content/drive/My Drive/EIP3/session7/utkface_128.npz'
savez_compressed(filename, images_A, images_B)
print('Saved dataset: ', filename)
Saved dataset:  /content/drive/My Drive/EIP3/session7/utkface_128.npz

Load the saved numpy arrays and plot some images from each domain

# load the dataset
data = load('/content/drive/My Drive/EIP3/session7/utkface_128.npz')
dataA, dataB = data['arr_0'], data['arr_1']
print('Loaded: ', dataA.shape, dataB.shape)
# plot source images
n_samples = 3
for i in range(n_samples):
	pyplot.subplot(2, n_samples, 1 + i)
	pyplot.axis('off')
	pyplot.imshow(dataA[i].astype('uint8'))
# plot target image
for i in range(n_samples):
	pyplot.subplot(2, n_samples, 1 + n_samples + i)
	pyplot.axis('off')
	pyplot.imshow(dataB[i].astype('uint8'))
pyplot.show()
Loaded:  (2000, 128, 128, 3) (2000, 128, 128, 3)

Install keras-contrib so that we can use InstanceNormalization instead of BatchNormalization

!pip install git+https://www.github.com/keras-team/keras-contrib.git
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.initializers import RandomNormal

Define helper functions for the various components of the model that we are going to build

Conv layers

def conv2d(input_, output_dim, ks=4, s=2, stddev=0.02, padding='SAME', name='c2d'):
  return Conv2D(output_dim, kernel_size=ks, strides=s, padding=padding,
                kernel_initializer=tf.truncated_normal_initializer(stddev=stddev),
                name=name)(input_)

Leaky ReLU

def lrelu(input_,name='lr'):
  return LeakyReLU(alpha=0.2,name=name)(input_)

InstanceNormalization

def iNorm(input_,name='iNorm'):
  return InstanceNormalization(axis=-1,name=name)(input_)

Discriminator Model

def build_discriminator(image_shape):
	# source image input
	in_image = Input(shape=image_shape)
	# C1: 64 filters, downsample to 64x64
	d1 = lrelu(conv2d(in_image, 64, 4, name='d_c1'), 'lr1')
	# C2: 256 filters, downsample to 32x32
	d3 = lrelu(iNorm(conv2d(d1, 256, 4, name='d_c3'), 'iN3'), 'lr3')
	# C3: 512 filters, downsample to 16x16
	d4 = lrelu(iNorm(conv2d(d3, 512, 4, name='d_c4'), 'iN4'), 'lr4')
	# patch output: one real/fake score per cell of a 16x16 grid
	d5 = conv2d(d4, 1, 4, 1, name='d_c5')
	# define model
	model = Model(in_image, d5)
	# compile model; the 0.5 loss weight halves the discriminator updates relative to the generator
	model.compile(loss='mse', optimizer=Adam(lr=0.0002, beta_1=0.5), loss_weights=[0.5])
	return model
disc=build_discriminator(dataB[0].shape)
disc.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 128, 128, 3)       0         
_________________________________________________________________
d_c1 (Conv2D)                (None, 64, 64, 64)        3136      
_________________________________________________________________
lr1 (LeakyReLU)              (None, 64, 64, 64)        0         
_________________________________________________________________
d_c3 (Conv2D)                (None, 32, 32, 256)       262400    
_________________________________________________________________
iN3 (InstanceNormalization)  (None, 32, 32, 256)       512       
_________________________________________________________________
lr3 (LeakyReLU)              (None, 32, 32, 256)       0         
_________________________________________________________________
d_c4 (Conv2D)                (None, 16, 16, 512)       2097664   
_________________________________________________________________
iN4 (InstanceNormalization)  (None, 16, 16, 512)       1024      
_________________________________________________________________
lr4 (LeakyReLU)              (None, 16, 16, 512)       0         
_________________________________________________________________
d_c5 (Conv2D)                (None, 16, 16, 1)         8193      
=================================================================
Total params: 2,372,929
Trainable params: 2,372,929
Non-trainable params: 0
_________________________________________________________________
Functions to add reflection padding

def padd3(input_):
  import tensorflow as tf  # local import so the Lambda layers can be deserialized when the model is reloaded
  return tf.pad(input_, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")

def padd1(input_):
  import tensorflow as tf
  return tf.pad(input_, [[0, 0], [1, 1], [1, 1], [0, 0]], "REFLECT")

The generator uses ResNet blocks, as defined below

def res_block(input_, nf=64, ks=3, s=1, name='res_blk'):
  # reflection-pad, then conv + instance norm, twice; finally add the block input (skip connection)
  y = Lambda(padd1)(input_)
  y = iNorm(conv2d(y, nf, ks, s, padding='VALID', name=name + '_c1'), name=name + '_iN1')
  y = Lambda(padd1)(y)
  y = iNorm(conv2d(y, nf, ks, s, padding='VALID', name=name + '_c2'), name=name + '_iN2')
  return Add()([y, input_])

Deconvolution (transposed convolution) layer

def deconv2d(input_, output_dim, ks=4, s=2, stddev=0.02, padding='SAME', name='dc2d'):
  return Conv2DTranspose(output_dim, (ks, ks), strides=(s, s), padding=padding,
                         kernel_initializer=tf.truncated_normal_initializer(stddev=stddev),
                         name=name)(input_)

Generator model

def build_generator(image_shape):
  nf = 64  # number of filters in the first layer
  input_ = Input(shape=image_shape)

  # encoder: reflection-pad then a 7x7 conv, followed by two stride-2 downsampling convs
  c0 = Lambda(padd3)(input_)
  c1 = Activation('relu')(iNorm(conv2d(c0, nf, 7, 1, padding='VALID', name='g_e1_c'), 'g_e1_bn'))
  c2 = Activation('relu')(iNorm(conv2d(c1, nf*2, 3, 2, name='g_e2_c'), 'g_e2_bn'))
  c3 = Activation('relu')(iNorm(conv2d(c2, nf*4, 3, 2, name='g_e3_c'), 'g_e3_bn'))

  # transformer: nine ResNet blocks at 32x32x256
  r1 = res_block(c3, nf*4, name='g_r1')
  r2 = res_block(r1, nf*4, name='g_r2')
  r3 = res_block(r2, nf*4, name='g_r3')
  r4 = res_block(r3, nf*4, name='g_r4')
  r5 = res_block(r4, nf*4, name='g_r5')
  r6 = res_block(r5, nf*4, name='g_r6')
  r7 = res_block(r6, nf*4, name='g_r7')
  r8 = res_block(r7, nf*4, name='g_r8')
  r9 = res_block(r8, nf*4, name='g_r9')

  # decoder: two stride-2 transposed convs back up to 128x128
  d1 = Activation('relu')(iNorm(deconv2d(r9, nf*2, 3, 2, name='g_d1_dc'), name='g_d1_bn'))
  d2 = Activation('relu')(iNorm(deconv2d(d1, nf, 3, 2, name='g_d2_dc'), name='g_d2_bn'))

  # reflection-pad then a final 7x7 conv down to 3 channels; tanh maps pixels to [-1, 1]
  d2 = Lambda(padd3)(d2)
  d3 = conv2d(d2, 3, 7, 1, padding='VALID', name='g_pred_c')
  pred = Activation('tanh')(d3)

  model = Model(input_, pred)
  return model
gen=build_generator(dataA[0].shape)
gen.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_3 (InputLayer)            (None, 128, 128, 3)  0                                            
__________________________________________________________________________________________________
lambda_20 (Lambda)              (None, 134, 134, 3)  0           input_3[0][0]                    
__________________________________________________________________________________________________
g_e1_c (Conv2D)                 (None, 128, 128, 64) 9472        lambda_20[0][0]                  
__________________________________________________________________________________________________
g_e1_bn (InstanceNormalization) (None, 128, 128, 64) 128         g_e1_c[0][0]                     
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 128, 128, 64) 0           g_e1_bn[0][0]                    
__________________________________________________________________________________________________
g_e2_c (Conv2D)                 (None, 64, 64, 128)  73856       activation_4[0][0]               
__________________________________________________________________________________________________
g_e2_bn (InstanceNormalization) (None, 64, 64, 128)  256         g_e2_c[0][0]                     
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 64, 64, 128)  0           g_e2_bn[0][0]                    
__________________________________________________________________________________________________
g_e3_c (Conv2D)                 (None, 32, 32, 256)  295168      activation_5[0][0]               
__________________________________________________________________________________________________
g_e3_bn (InstanceNormalization) (None, 32, 32, 256)  512         g_e3_c[0][0]                     
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 32, 32, 256)  0           g_e3_bn[0][0]                    
__________________________________________________________________________________________________
lambda_21 (Lambda)              (None, 34, 34, 256)  0           activation_6[0][0]               
__________________________________________________________________________________________________
g_r1_c1 (Conv2D)                (None, 32, 32, 256)  590080      lambda_21[0][0]                  
__________________________________________________________________________________________________
g_r1_iN1 (InstanceNormalization (None, 32, 32, 256)  512         g_r1_c1[0][0]                    
__________________________________________________________________________________________________
lambda_22 (Lambda)              (None, 34, 34, 256)  0           g_r1_iN1[0][0]                   
__________________________________________________________________________________________________
g_r1_c2 (Conv2D)                (None, 32, 32, 256)  590080      lambda_22[0][0]                  
__________________________________________________________________________________________________
g_r1_iN2 (InstanceNormalization (None, 32, 32, 256)  512         g_r1_c2[0][0]                    
__________________________________________________________________________________________________
add_10 (Add)                    (None, 32, 32, 256)  0           g_r1_iN2[0][0]                   
                                                                 activation_6[0][0]               
__________________________________________________________________________________________________
lambda_23 (Lambda)              (None, 34, 34, 256)  0           add_10[0][0]                     
__________________________________________________________________________________________________
g_r2_c1 (Conv2D)                (None, 32, 32, 256)  590080      lambda_23[0][0]                  
__________________________________________________________________________________________________
g_r2_iN1 (InstanceNormalization (None, 32, 32, 256)  512         g_r2_c1[0][0]                    
__________________________________________________________________________________________________
lambda_24 (Lambda)              (None, 34, 34, 256)  0           g_r2_iN1[0][0]                   
__________________________________________________________________________________________________
g_r2_c2 (Conv2D)                (None, 32, 32, 256)  590080      lambda_24[0][0]                  
__________________________________________________________________________________________________
g_r2_iN2 (InstanceNormalization (None, 32, 32, 256)  512         g_r2_c2[0][0]                    
__________________________________________________________________________________________________
add_11 (Add)                    (None, 32, 32, 256)  0           g_r2_iN2[0][0]                   
                                                                 add_10[0][0]                     
__________________________________________________________________________________________________
lambda_25 (Lambda)              (None, 34, 34, 256)  0           add_11[0][0]                     
__________________________________________________________________________________________________
g_r3_c1 (Conv2D)                (None, 32, 32, 256)  590080      lambda_25[0][0]                  
__________________________________________________________________________________________________
g_r3_iN1 (InstanceNormalization (None, 32, 32, 256)  512         g_r3_c1[0][0]                    
__________________________________________________________________________________________________
lambda_26 (Lambda)              (None, 34, 34, 256)  0           g_r3_iN1[0][0]                   
__________________________________________________________________________________________________
g_r3_c2 (Conv2D)                (None, 32, 32, 256)  590080      lambda_26[0][0]                  
__________________________________________________________________________________________________
g_r3_iN2 (InstanceNormalization (None, 32, 32, 256)  512         g_r3_c2[0][0]                    
__________________________________________________________________________________________________
add_12 (Add)                    (None, 32, 32, 256)  0           g_r3_iN2[0][0]                   
                                                                 add_11[0][0]                     
__________________________________________________________________________________________________
lambda_27 (Lambda)              (None, 34, 34, 256)  0           add_12[0][0]                     
__________________________________________________________________________________________________
g_r4_c1 (Conv2D)                (None, 32, 32, 256)  590080      lambda_27[0][0]                  
__________________________________________________________________________________________________
g_r4_iN1 (InstanceNormalization (None, 32, 32, 256)  512         g_r4_c1[0][0]                    
__________________________________________________________________________________________________
lambda_28 (Lambda)              (None, 34, 34, 256)  0           g_r4_iN1[0][0]                   
__________________________________________________________________________________________________
g_r4_c2 (Conv2D)                (None, 32, 32, 256)  590080      lambda_28[0][0]                  
__________________________________________________________________________________________________
g_r4_iN2 (InstanceNormalization (None, 32, 32, 256)  512         g_r4_c2[0][0]                    
__________________________________________________________________________________________________
add_13 (Add)                    (None, 32, 32, 256)  0           g_r4_iN2[0][0]                   
                                                                 add_12[0][0]                     
__________________________________________________________________________________________________
lambda_29 (Lambda)              (None, 34, 34, 256)  0           add_13[0][0]                     
__________________________________________________________________________________________________
g_r5_c1 (Conv2D)                (None, 32, 32, 256)  590080      lambda_29[0][0]                  
__________________________________________________________________________________________________
g_r5_iN1 (InstanceNormalization (None, 32, 32, 256)  512         g_r5_c1[0][0]                    
__________________________________________________________________________________________________
lambda_30 (Lambda)              (None, 34, 34, 256)  0           g_r5_iN1[0][0]                   
__________________________________________________________________________________________________
g_r5_c2 (Conv2D)                (None, 32, 32, 256)  590080      lambda_30[0][0]                  
__________________________________________________________________________________________________
g_r5_iN2 (InstanceNormalization (None, 32, 32, 256)  512         g_r5_c2[0][0]                    
__________________________________________________________________________________________________
add_14 (Add)                    (None, 32, 32, 256)  0           g_r5_iN2[0][0]                   
                                                                 add_13[0][0]                     
__________________________________________________________________________________________________
lambda_31 (Lambda)              (None, 34, 34, 256)  0           add_14[0][0]                     
__________________________________________________________________________________________________
g_r6_c1 (Conv2D)                (None, 32, 32, 256)  590080      lambda_31[0][0]                  
__________________________________________________________________________________________________
g_r6_iN1 (InstanceNormalization (None, 32, 32, 256)  512         g_r6_c1[0][0]                    
__________________________________________________________________________________________________
lambda_32 (Lambda)              (None, 34, 34, 256)  0           g_r6_iN1[0][0]                   
__________________________________________________________________________________________________
g_r6_c2 (Conv2D)                (None, 32, 32, 256)  590080      lambda_32[0][0]                  
__________________________________________________________________________________________________
g_r6_iN2 (InstanceNormalization (None, 32, 32, 256)  512         g_r6_c2[0][0]                    
__________________________________________________________________________________________________
add_15 (Add)                    (None, 32, 32, 256)  0           g_r6_iN2[0][0]                   
                                                                 add_14[0][0]                     
__________________________________________________________________________________________________
lambda_33 (Lambda)              (None, 34, 34, 256)  0           add_15[0][0]                     
__________________________________________________________________________________________________
g_r7_c1 (Conv2D)                (None, 32, 32, 256)  590080      lambda_33[0][0]                  
__________________________________________________________________________________________________
g_r7_iN1 (InstanceNormalization (None, 32, 32, 256)  512         g_r7_c1[0][0]                    
__________________________________________________________________________________________________
lambda_34 (Lambda)              (None, 34, 34, 256)  0           g_r7_iN1[0][0]                   
__________________________________________________________________________________________________
g_r7_c2 (Conv2D)                (None, 32, 32, 256)  590080      lambda_34[0][0]                  
__________________________________________________________________________________________________
g_r7_iN2 (InstanceNormalization (None, 32, 32, 256)  512         g_r7_c2[0][0]                    
__________________________________________________________________________________________________
add_16 (Add)                    (None, 32, 32, 256)  0           g_r7_iN2[0][0]                   
                                                                 add_15[0][0]                     
__________________________________________________________________________________________________
lambda_35 (Lambda)              (None, 34, 34, 256)  0           add_16[0][0]                     
__________________________________________________________________________________________________
g_r8_c1 (Conv2D)                (None, 32, 32, 256)  590080      lambda_35[0][0]                  
__________________________________________________________________________________________________
g_r8_iN1 (InstanceNormalization (None, 32, 32, 256)  512         g_r8_c1[0][0]                    
__________________________________________________________________________________________________
lambda_36 (Lambda)              (None, 34, 34, 256)  0           g_r8_iN1[0][0]                   
__________________________________________________________________________________________________
g_r8_c2 (Conv2D)                (None, 32, 32, 256)  590080      lambda_36[0][0]                  
__________________________________________________________________________________________________
g_r8_iN2 (InstanceNormalization (None, 32, 32, 256)  512         g_r8_c2[0][0]                    
__________________________________________________________________________________________________
add_17 (Add)                    (None, 32, 32, 256)  0           g_r8_iN2[0][0]                   
                                                                 add_16[0][0]                     
__________________________________________________________________________________________________
lambda_37 (Lambda)              (None, 34, 34, 256)  0           add_17[0][0]                     
__________________________________________________________________________________________________
g_r9_c1 (Conv2D)                (None, 32, 32, 256)  590080      lambda_37[0][0]                  
__________________________________________________________________________________________________
g_r9_iN1 (InstanceNormalization (None, 32, 32, 256)  512         g_r9_c1[0][0]                    
__________________________________________________________________________________________________
lambda_38 (Lambda)              (None, 34, 34, 256)  0           g_r9_iN1[0][0]                   
__________________________________________________________________________________________________
g_r9_c2 (Conv2D)                (None, 32, 32, 256)  590080      lambda_38[0][0]                  
__________________________________________________________________________________________________
g_r9_iN2 (InstanceNormalization (None, 32, 32, 256)  512         g_r9_c2[0][0]                    
__________________________________________________________________________________________________
add_18 (Add)                    (None, 32, 32, 256)  0           g_r9_iN2[0][0]                   
                                                                 add_17[0][0]                     
__________________________________________________________________________________________________
g_d1_dc (Conv2DTranspose)       (None, 64, 64, 128)  295040      add_18[0][0]                     
__________________________________________________________________________________________________
g_d1_bn (InstanceNormalization) (None, 64, 64, 128)  256         g_d1_dc[0][0]                    
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 64, 64, 128)  0           g_d1_bn[0][0]                    
__________________________________________________________________________________________________
g_d2_dc (Conv2DTranspose)       (None, 128, 128, 64) 73792       activation_7[0][0]               
__________________________________________________________________________________________________
g_d2_bn (InstanceNormalization) (None, 128, 128, 64) 128         g_d2_dc[0][0]                    
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 128, 128, 64) 0           g_d2_bn[0][0]                    
__________________________________________________________________________________________________
lambda_39 (Lambda)              (None, 134, 134, 64) 0           activation_8[0][0]               
__________________________________________________________________________________________________
g_pred_c (Conv2D)               (None, 128, 128, 3)  9411        lambda_39[0][0]                  
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 128, 128, 3)  0           g_pred_c[0][0]                   
==================================================================================================
Total params: 11,388,675
Trainable params: 11,388,675
Non-trainable params: 0
__________________________________________________________________________________________________

Composite model with two generators and a discriminator

def build_composite_model(g_model_1, d_model, g_model_2, image_shape):
	# ensure the model we're updating is trainable
	g_model_1.trainable = True
	# mark discriminator as not trainable
	d_model.trainable = False
	# mark other generator model as not trainable
	g_model_2.trainable = False
	# discriminator element
	input_gen = Input(shape=image_shape)
	gen1_out = g_model_1(input_gen)
	output_d = d_model(gen1_out)
	# identity element
	input_id = Input(shape=image_shape)
	output_id = g_model_1(input_id)
	# forward cycle
	output_f = g_model_2(gen1_out)
	# backward cycle
	gen2_out = g_model_2(input_id)
	output_b = g_model_1(gen2_out)
	# define model graph
	model = Model([input_gen, input_id], [output_d, output_id, output_f, output_b])
	# define optimization algorithm configuration
	opt = Adam(lr=0.0002, beta_1=0.5)
	# compile model with weighting of least squares loss and L1 loss
	model.compile(loss=['mse', 'mae', 'mae', 'mae'], loss_weights=[1, 5, 10, 10], optimizer=opt)
	return model
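
The four loss weights [1, 5, 10, 10] correspond to the model's four outputs: adversarial, identity, forward cycle and backward cycle. This follows the paper's weighting, where the cycle-consistency terms use λ = 10 and the identity loss is weighted at half of λ.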

The full domains hold 3404 (young) and 2592 (old) samples, and even the 2000 images we stored per domain made each epoch slow in the initial training runs. So we will use a function to draw a subsample of 1000 images per domain.

def get_subsample(dataset):
    # draw three random contiguous chunks (300 + 400 + 300 = 1000 images) from each domain,
    # keeping all offsets within the 2000 images stored per domain
    t1 = np.random.randint(0, 700)
    t2 = np.random.randint(700, 1300)
    t3 = np.random.randint(1300, 1700)
    return (np.vstack((dataset[0][t1:t1+300], dataset[0][t2:t2+400], dataset[0][t3:t3+300])),
            np.vstack((dataset[1][t1:t1+300], dataset[1][t2:t2+400], dataset[1][t3:t3+300])))


# variant for the three-array dataset loaded by load_real_samples2 below
def get_subsample2(dataset):
  t0 = np.random.randint(250)
  t1 = np.random.randint(300)
  t2 = np.random.randint(1200, 2000)
  return np.vstack((dataset[0][t1:t1+800], dataset[0][t2:t2+200])), np.vstack((dataset[1][t1:t1+100], dataset[2][t0:t0+900]))
 

Utility functions to load the image samples, generate fake images, save models, save generated images, etc.

# load a three-array dataset and scale pixels from [0,255] to [-1,1]
def load_real_samples2(filename):
  data = load(filename)
  X1, X2, X3 = data['arr_0'], data['arr_1'], data['arr_2']
  X1 = (X1 - 127.5) / 127.5
  X2 = (X2 - 127.5) / 127.5
  X3 = (X3 - 127.5) / 127.5
  return X1, X2, X3

def load_real_samples(filename):
	# load the dataset
	data = load(filename)
	# unpack arrays
	X1, X2 = data['arr_0'], data['arr_1']
	# scale from [0,255] to [-1,1]
	X1 = (X1 - 127.5) / 127.5
	X2 = (X2 - 127.5) / 127.5
	return [X1, X2]
 
# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape):
	# choose random instances
	ix = randint(0, dataset.shape[0], n_samples)
	# retrieve selected images
	X = dataset[ix]
	# generate 'real' class labels (1)
	y = ones((n_samples, patch_shape, patch_shape, 1))
	return X, y
 
# generate a batch of images, returns images and targets
def generate_fake_samples(g_model, dataset, patch_shape):
	# generate fake instance
	X = g_model.predict(dataset)
	# create 'fake' class labels (0)
	y = zeros((len(X), patch_shape, patch_shape, 1))
	return X, y
 
# save the generator models to file
def save_models(step, g_model_AtoB, g_model_BtoA):
	# save the first generator model
	filename1 = '/content/drive/My Drive/EIP3/session7/g_model_AtoB_%06d.h5' % (step+1)
	g_model_AtoB.save(filename1)
	# save the second generator model
	filename2 = '/content/drive/My Drive/EIP3/session7/g_model_BtoA_%06d.h5' % (step+1)
	g_model_BtoA.save(filename2)
	print('>Saved: %s and %s' % (filename1, filename2))
  
# save the generator and discriminator models to file
def save_models2(step, g_model_AtoB, g_model_BtoA, d_model_A, d_model_B):
	# save the first generator model
	filename1 = '/content/drive/My Drive/EIP3/session7/g_model_AtoB_%06d.h5' % (step+1)
	g_model_AtoB.save(filename1)
	# save the second generator model
	filename2 = '/content/drive/My Drive/EIP3/session7/g_model_BtoA_%06d.h5' % (step+1)
	g_model_BtoA.save(filename2)
	# save discriminator model A
	filename3 = '/content/drive/My Drive/EIP3/session7/d_model_A_%06d.h5' % (step+1)
	d_model_A.save(filename3)
	# save discriminator model B
	filename4 = '/content/drive/My Drive/EIP3/session7/d_model_B_%06d.h5' % (step+1)
	d_model_B.save(filename4)
	print('>Saved: %s , %s , %s and %s' % (filename1, filename2, filename3, filename4))
  
def summarize_performance(step, g_model, trainX, name, n_samples=5):

	pyplot.figure(figsize=(15, 8), dpi=120)
 
	# select a sample of input images
	X_in, _ = generate_real_samples(trainX, n_samples, 0)
	# generate translated images
	X_out, _ = generate_fake_samples(g_model, X_in, 0)
	# scale all pixels from [-1,1] to [0,1]
	X_in = (X_in + 1) / 2.0
	X_out = (X_out + 1) / 2.0
	# plot real images
	for i in range(n_samples):
		pyplot.subplot(2, n_samples, 1 + i)
		pyplot.axis('off')
		pyplot.imshow(X_in[i])
	# plot translated image
	for i in range(n_samples):
		pyplot.subplot(2, n_samples, 1 + n_samples + i)
		pyplot.axis('off')
		pyplot.imshow(X_out[i])
	# save plot to file
	filename1 = '/content/drive/My Drive/EIP3/session7/%s_generated_plot_%06d.png' % (name, (step+1))
	pyplot.savefig(filename1)
	pyplot.close()

Maintain a pool of 50 generated images, as described in the paper, so the discriminators are updated from a history of fakes rather than only the latest ones

def update_image_pool(pool, images, max_size=50):
	selected = list()
	for image in images:
		if len(pool) < max_size:
			# stock the pool
			pool.append(image)
			selected.append(image)
		elif random() < 0.5:
			# use image, but don't add it to the pool
			selected.append(image)
		else:
			# replace an existing image and use replaced image
			ix = randint(0, len(pool))
			selected.append(pool[ix])
			pool[ix] = image
	return asarray(selected)

Function to run the training loop

def train(d_model_A, d_model_B, g_model_AtoB, g_model_BtoA, c_model_AtoB, c_model_BtoA, dataset,batch_size,n_epochs):
	# define properties of the training run
	n_batch = batch_size
	# determine the output square shape of the discriminator
	n_patch = d_model_A.output_shape[1]
	# unpack dataset
	trainA, trainB = get_subsample(dataset)
	# prepare image pool for fakes
	poolA, poolB = list(), list()
	# calculate the number of batches per training epoch
	bat_per_epo = int(len(trainA) / n_batch)
	# calculate the number of training iterations
	n_steps = bat_per_epo * n_epochs
	# manually enumerate epochs
	for i in range(n_steps):
		# select a batch of real samples
		X_realA, y_realA = generate_real_samples(trainA, n_batch, n_patch)
		X_realB, y_realB = generate_real_samples(trainB, n_batch, n_patch)
		# generate a batch of fake samples
		X_fakeA, y_fakeA = generate_fake_samples(g_model_BtoA, X_realB, n_patch)
		X_fakeB, y_fakeB = generate_fake_samples(g_model_AtoB, X_realA, n_patch)
		# update fakes from pool
		X_fakeA = update_image_pool(poolA, X_fakeA)
		X_fakeB = update_image_pool(poolB, X_fakeB)
		# update generator B->A via adversarial and cycle loss
		g_loss2, _, _, _, _  = c_model_BtoA.train_on_batch([X_realB, X_realA], [y_realA, X_realA, X_realB, X_realA])
		# update discriminator for A -> [real/fake]
		dA_loss1 = d_model_A.train_on_batch(X_realA, y_realA)
		dA_loss2 = d_model_A.train_on_batch(X_fakeA, y_fakeA)
		# update generator A->B via adversarial and cycle loss
		g_loss1, _, _, _, _ = c_model_AtoB.train_on_batch([X_realA, X_realB], [y_realB, X_realB, X_realA, X_realB])
		# update discriminator for B -> [real/fake]
		dB_loss1 = d_model_B.train_on_batch(X_realB, y_realB)
		dB_loss2 = d_model_B.train_on_batch(X_fakeB, y_fakeB)
		# summarize performance
		print('>%d, dA[%.3f,%.3f] dB[%.3f,%.3f] g[%.3f,%.3f]' % (i+1, dA_loss1,dA_loss2, dB_loss1,dB_loss2, g_loss1,g_loss2))
		# evaluate the model performance every so often
		if (i+1) % (bat_per_epo * 1) == 0:
			# plot A->B translation
			summarize_performance(i, g_model_AtoB, trainA, 'AtoB')
			# plot B->A translation
			summarize_performance(i, g_model_BtoA, trainB, 'BtoA')
		if (i+1) % (bat_per_epo * 5) == 0:
			# save the models
			save_models2(i, g_model_AtoB, g_model_BtoA,d_model_A,d_model_B)

define the models and run training

# load image data
dataset = load_real_samples('/content/drive/My Drive/EIP3/session7/utkface_128.npz')
print('Loaded', dataset[0].shape, dataset[1].shape)
# define input shape based on the loaded dataset
image_shape = dataset[0].shape[1:]
# generator: A -> B
g_model_AtoB = build_generator(image_shape)
# generator: B -> A
g_model_BtoA = build_generator(image_shape)
# discriminator: A -> [real/fake]
d_model_A = build_discriminator(image_shape)
# discriminator: B -> [real/fake]
d_model_B = build_discriminator(image_shape)
# composite: A -> B -> [real/fake, A]
c_model_AtoB = build_composite_model(g_model_AtoB, d_model_B, g_model_BtoA, image_shape)
# composite: B -> A -> [real/fake, B]
c_model_BtoA = build_composite_model(g_model_BtoA, d_model_A, g_model_AtoB, image_shape)
# train models
train(d_model_A, d_model_B, g_model_AtoB, g_model_BtoA, c_model_AtoB, c_model_BtoA, dataset,batch_size=2,n_epochs=10)
To resume training later, we reload the data and the previously saved generator models:

# load image data
dataset = load_real_samples('/content/drive/My Drive/EIP3/session7/utkface_128.npz')
print('Loaded', dataset[0].shape, dataset[1].shape)
# define input shape based on the loaded dataset
image_shape = dataset[0].shape[1:]

#load the previously trained model 

cust = {'InstanceNormalization': InstanceNormalization, 'tf': tf}
# generator: A -> B
g_model_AtoB = load_model('/content/drive/My Drive/EIP3/session7/g_model_AtoB_005625.h5', cust)
# generator: B -> A
g_model_BtoA = load_model('/content/drive/My Drive/EIP3/session7/g_model_BtoA_005625.h5', cust)
 
# discriminator: A -> [real/fake]
d_model_A = build_discriminator(image_shape)
# discriminator: B -> [real/fake]
d_model_B = build_discriminator(image_shape)
# composite: A -> B -> [real/fake, A]
c_model_AtoB = build_composite_model(g_model_AtoB, d_model_B, g_model_BtoA, image_shape)
# composite: B -> A -> [real/fake, B]
c_model_BtoA = build_composite_model(g_model_BtoA, d_model_A, g_model_AtoB, image_shape)
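
Note that only the generators are restored from checkpoints here; the discriminators are rebuilt from scratch. Since save_models2 also saves d_model_A and d_model_B, reloading those with load_model in the same way would let the discriminators resume as well.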
# load a second, three-array dataset (prepared separately; its construction is not shown in this post)
dataset = load_real_samples2('/content/drive/My Drive/EIP3/session7/utkface_128_2.npz')

After training the model for slightly over 100 epochs (more epochs would give better results), we can inspect the results of this training using the show_results function.

def show_results(g_model, trainX, n_samples=5, title='A to B'):
  pyplot.figure(figsize=(12, 6), dpi=120)
  # select a sample of input images
  X_in, _ = generate_real_samples(trainX, n_samples, 0)
  # generate translated images
  X_out, _ = generate_fake_samples(g_model, X_in, 0)
  # scale all pixels from [-1,1] to [0,1]
  X_in = (X_in + 1) / 2.0
  X_out = (X_out + 1) / 2.0
  # plot the real input images
  for i in range(n_samples):
    pyplot.subplot(2, n_samples, 1 + i)
    pyplot.axis('off')
    pyplot.imshow(X_in[i])
  pyplot.show()
  print("      ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ " + title + " ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓")
  # plot the translated images
  pyplot.figure(figsize=(12, 6), dpi=120)
  for i in range(n_samples):
    pyplot.subplot(2, n_samples, 1 + n_samples + i)
    pyplot.axis('off')
    pyplot.imshow(X_out[i])
  pyplot.show()

  
  
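As an example, calling it in both directions might look like the sketch below (the checkpoint step 005625 matches the one loaded earlier; adapt it to whichever checkpoint you trained):

# A = young faces, B = old faces; cust supplies InstanceNormalization and tf for deserialization
g_AtoB = load_model('/content/drive/My Drive/EIP3/session7/g_model_AtoB_005625.h5', cust)
g_BtoA = load_model('/content/drive/My Drive/EIP3/session7/g_model_BtoA_005625.h5', cust)

show_results(g_AtoB, dataset[0], title='A to B (Young to Old)')
show_results(g_BtoA, dataset[1], title='B to A (Old to Young)')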

Here are some samples of the generated images:

Old to young sample 1


Old to young sample 2


Old to young sample 3


Old to young sample 4


Training the models for more epochs would improve the results, especially for the young-to-old translation. We also used 128x128 images due to time and compute constraints; training on the original 200x200 image size would likely have yielded better results.