Transform Pokemon Resolution
This is a tool that re-sharpens Pokémon images that are too blurry, using the pix2pix architecture.
Results
Code
# -*- coding: utf-8 -*-
"""Untitled2.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1Vn6TKRB7i2O06SLPc3WLiD_Ht-al89Ha
"""
# Commented out IPython magic to ensure Python compatibility.
# %tensorflow_version 2.x
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
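# The paths below live in Google Drive, so the notebook needs Drive mounted
# first (standard Colab API; this cell is assumed, not in the original dump):
from google.colab import drive
drive.mount('/content/drive')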
# Root path
PATH = '/content/drive/MyDrive/Transform-Pokemon-Resolution'
# Input path (blurred source images)
INPATH = PATH + "/Input Pokemon"
# Target path (sharp ground-truth images)
OUPATH = PATH + "/outputPoke"
# Checkpoints
CKPATH = PATH + '/checkpoints'
imgurls = !ls -1 "{INPATH}"  # IPython shell magic: list the input folder
n = 151  # number of images to use
train_n = round(n * 0.80)  # 80% of the images go to the training split
# Shuffled copy of the file list (fixed seed for reproducibility)
randurls = np.copy(imgurls)
np.random.seed(23)
np.random.shuffle(randurls)
# Train / test split
tr_urls = randurls[:train_n]
ts_urls = randurls[train_n:n]
# 151 files in the folder: 121 for training, 30 held out for testing
print(len(imgurls), len(tr_urls), len(ts_urls))
IMG_WIDTH = 256
IMG_HEIGHT = 256
# Resize both images to the given size
def resize(inimg, tgimg, height, width):
inimg = tf.image.resize(inimg, [height, width])
tgimg = tf.image.resize(tgimg, [height, width])
return inimg, tgimg
# Normalize the images to the range [-1, +1]
def normalize(inimg, tgimg):
inimg = (inimg/127.5) - 1
tgimg = (tgimg/127.5) - 1
return inimg, tgimg
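# Quick check of the mapping (illustrative): pixel value 0 becomes -1,
# 127.5 becomes 0, and 255 becomes +1:
# print(normalize(tf.constant([0.0, 127.5, 255.0]), tf.constant([0.0]))[0])
# -> tf.Tensor([-1.  0.  1.], ...)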
# Data augmentation: upscale to 286x286, random-crop back to 256x256, then random horizontal flip (pix2pix jitter)
def random_jitter(inimg, tgimg):
inimg, tgimg = resize(inimg, tgimg, 286, 286)
stacked_image = tf.stack([inimg, tgimg], axis = 0)
cropped_image = tf.image.random_crop(stacked_image, size = [2, IMG_HEIGHT, IMG_WIDTH, 3])
inimg, tgimg = cropped_image[0], cropped_image[1]
if tf.random.uniform(()) > 0.5:
inimg = tf.image.flip_left_right(inimg)
tgimg = tf.image.flip_left_right(tgimg)
return inimg, tgimg
# Input and target share the same filename, living in INPATH and OUPATH respectively
def load_image(filename, augment=True):
inimg = tf.cast(tf.image.decode_jpeg(tf.io.read_file(INPATH + '/' + filename)), tf.float32)[..., :3]
tgimg = tf.cast(tf.image.decode_jpeg(tf.io.read_file(OUPATH + '/' + filename)), tf.float32)[..., :3]
inimg, tgimg = resize(inimg, tgimg, IMG_HEIGHT, IMG_WIDTH)
if augment:
inimg, tgimg = random_jitter(inimg, tgimg)
inimg, tgimg = normalize(inimg, tgimg)
return inimg, tgimg
def load_train_image(filename):
return load_image(filename, True)
def load_test_image(filename):
return load_image(filename, False)
# Quick visual check: map a target image from [-1, 1] back to [0, 1]
plt.imshow(load_train_image(randurls[0])[1] * 0.5 + 0.5)
# Data loading
train_dataset = tf.data.Dataset.from_tensor_slices(tr_urls)
train_dataset = train_dataset.map(load_train_image, num_parallel_calls = tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.batch(1)
"""for inimg, tgimg in train_dataset.take(5): Para probar las imagenes
plt.imshow(((tgimg[0,...]) + 1) / 2)
plt.show()"""
# Build the test pipeline from the held-out ts_urls split
test_dataset = tf.data.Dataset.from_tensor_slices(ts_urls)
test_dataset = test_dataset.map(load_test_image, num_parallel_calls = tf.data.experimental.AUTOTUNE)
test_dataset = test_dataset.batch(1)
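# Optional pipeline tweaks (assumption: not in the original notebook): shuffling
# decorrelates consecutive batches and prefetching overlaps preprocessing with
# training. Both are standard tf.data calls:
# train_dataset = train_dataset.shuffle(buffer_size=len(tr_urls)).prefetch(tf.data.experimental.AUTOTUNE)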
from tensorflow.keras import *
from tensorflow.keras.layers import *
def downsample(filters, apply_batchnorm=True):
    result = Sequential()
    initializer = tf.random_normal_initializer(0, 0.02)
    # Convolutional layer (stride 2 halves the spatial resolution)
    result.add(Conv2D(filters,
                      kernel_size=4,
                      strides=2,
                      padding="same",
                      kernel_initializer=initializer,
                      use_bias=not apply_batchnorm))
    if apply_batchnorm:
        # Batch-normalization layer
        result.add(BatchNormalization())
    # Activation layer
    result.add(LeakyReLU())
    return result
downsample(64)
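# Shape check on a hypothetical input: a stride-2 block maps 256x256 -> 128x128
print(downsample(64)(tf.zeros([1, 256, 256, 3])).shape)  # (1, 128, 128, 64)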
def upsample(filters, apply_dropout=True):
    result = Sequential()
    initializer = tf.random_normal_initializer(0, 0.02)
    # Transposed convolution (stride 2 doubles the spatial resolution)
    result.add(Conv2DTranspose(filters,
                               kernel_size=4,
                               strides=2,
                               padding="same",
                               kernel_initializer=initializer,
                               use_bias=False))
    # Batch-normalization layer
    result.add(BatchNormalization())
    if apply_dropout:
        # Dropout layer
        result.add(Dropout(0.5))
    # Activation layer
    result.add(ReLU())
    return result
upsample(64)
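# Shape check on a hypothetical input: a stride-2 transposed block maps 128x128 -> 256x256
print(upsample(64)(tf.zeros([1, 128, 128, 512])).shape)  # (1, 256, 256, 64)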
def Generator():
    inputs = tf.keras.layers.Input(shape=[None, None, 3])
    # Encoder: each block halves the spatial resolution
    down_stack = [
        downsample(64, apply_batchnorm=False),  # (bs, 128, 128, 64)
        downsample(128),                        # (bs, 64, 64, 128)
        downsample(256),                        # (bs, 32, 32, 256)
        downsample(512),                        # (bs, 16, 16, 512)
        downsample(512),                        # (bs, 8, 8, 512)
        downsample(512),                        # (bs, 4, 4, 512)
        downsample(512),                        # (bs, 2, 2, 512)
        downsample(512),                        # (bs, 1, 1, 512)
    ]
    # Decoder: each block doubles the spatial resolution
    up_stack = [
        upsample(512, apply_dropout=False),  # (bs, 2, 2, 1024)
        upsample(512, apply_dropout=False),  # (bs, 4, 4, 1024)
        upsample(512, apply_dropout=False),  # (bs, 8, 8, 1024)
        upsample(512),                       # (bs, 16, 16, 1024)
        upsample(256),                       # (bs, 32, 32, 512)
        upsample(128),                       # (bs, 64, 64, 256)
        upsample(64),                        # (bs, 128, 128, 128)
    ]
    initializer = tf.random_normal_initializer(0, 0.02)
    # Final layer maps back to 3 channels in [-1, 1] via tanh
    last = Conv2DTranspose(filters=3,
                           kernel_size=4,
                           strides=2,
                           padding="same",
                           kernel_initializer=initializer,
                           activation="tanh")
    x = inputs
    # U-Net skip connections: save each encoder activation...
    skips = []
    concat = Concatenate()
    for down in down_stack:
        x = down(x)
        skips.append(x)
    # ...and concatenate it with the matching decoder activation
    skips = reversed(skips[:-1])
    for up, sk in zip(up_stack, skips):
        x = up(x)
        x = concat([x, sk])
    x = last(x)
    return Model(inputs=inputs, outputs=x)
generator = Generator()
# Grab one sample batch to inspect the untrained generator's output
inimg, tgimg = next(iter(train_dataset.take(1)))
gen_output = generator(inimg, training=False)
plt.imshow(gen_output[0, ...] * 0.5 + 0.5)  # map tanh output from [-1, 1] to [0, 1]
def Discriminator():
    # PatchGAN: takes the input image and a candidate image, concatenated channel-wise
    ini = Input(shape=[None, None, 3], name="input_img")
    gen = Input(shape=[None, None, 3], name="gener_img")
    con = concatenate([ini, gen])
    initializers = tf.random_normal_initializer(0, 0.02)
    down1 = downsample(64, apply_batchnorm=False)(con)
    down2 = downsample(128)(down1)
    down3 = downsample(256)(down2)
    down4 = downsample(512)(down3)
    # One logit per receptive-field patch (no sigmoid: the loss uses from_logits)
    last = tf.keras.layers.Conv2D(filters=1,
                                  kernel_size=4,
                                  strides=1,
                                  kernel_initializer=initializers,
                                  padding="same")(down4)
    return tf.keras.Model(inputs=[ini, gen], outputs=last)
discriminator = Discriminator()
disc_out = discriminator([inimg, gen_output], training=False)
plt.imshow(disc_out[0,...,-1],vmin=-20,vmax=20,cmap='RdBu_r')
plt.colorbar()
disc_out.shape
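# With a 256x256 input, the four stride-2 blocks leave a 16x16x1 grid of
# logits: each one classifies a patch of the image as real or fake (PatchGAN).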
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(disc_real_output, disc_generated_output):
    # How far the discriminator is from labelling real images as real (ones)
    real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    # How far it is from labelling generated images as fake (zeros)
    generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
    total_disc_loss = real_loss + generated_loss
    return total_disc_loss
LAMBDA = 100
# Takes the discriminator's verdict, the generated image, and the desired target image
def generator_loss(disc_generated_output, gen_output, target):
    # Adversarial term: push the discriminator toward outputting ones
    gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    # L1 term: mean absolute error against the target
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    total_gen_loss = gan_loss + (LAMBDA * l1_loss)
    return total_gen_loss
# Define the optimizers and set up checkpoints
# Checkpoints store the network's training state (in case the session crashes)
import os
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
checkpoint_prefix = os.path.join(CKPATH, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
discriminator_optimizer=discriminator_optimizer,
generator = generator,
discriminator = discriminator)
checkpoint.restore(tf.train.latest_checkpoint(CKPATH))  # restoring without assert_consumed(), which throws errors here; see the note below
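# A likely explanation (assumption): assert_consumed() fails because the Adam
# slot variables are only created lazily on the first apply_gradients call, so
# parts of the checkpoint are unmatched before training starts. Two standard
# TF APIs sidestep this:
# checkpoint.restore(tf.train.latest_checkpoint(CKPATH)).expect_partial()
# or, to still validate the weights that do exist:
# checkpoint.restore(tf.train.latest_checkpoint(CKPATH)).assert_existing_objects_matched()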
def generate_images(model, test_input, tar, save_filename=False, display_imgs=True):
    prediction = model(test_input, training=True)
    if save_filename:
        tf.keras.preprocessing.image.save_img(PATH + '/output/' + save_filename + '.jpg', prediction[0, ...])
    plt.figure(figsize=(10, 10))
    display_list = [test_input[0], tar[0], prediction[0]]
    title = ['Input Image', 'Ground Truth', 'Predicted Image']
    if display_imgs:
        for i in range(3):
            plt.subplot(1, 3, i + 1)
            plt.title(title[i])
            # Map the pixel values from [-1, 1] back to [0, 1] to plot them
            plt.imshow(display_list[i] * 0.5 + 0.5)
            plt.axis('off')
        plt.show()
def train_step(input_image, target):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as discr_tape:
        output_image = generator(input_image, training=True)
        output_gen_discr = discriminator([output_image, input_image], training=True)
        output_trg_discr = discriminator([target, input_image], training=True)
        discr_loss = discriminator_loss(output_trg_discr, output_gen_discr)
        gen_loss = generator_loss(output_gen_discr, output_image, target)
    generator_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
    discriminator_grads = discr_tape.gradient(discr_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(generator_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_grads, discriminator.trainable_variables))
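# Optional speed-up (assumption: standard TF 2.x practice, not in the original
# notebook): compiling the step into a graph usually trains much faster.
# train_step = tf.function(train_step)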
from IPython.display import clear_output
def train(dataset, epochs):
    for epoch in range(epochs):
        imgi = 0
        for input_image, target in dataset:
            print('epoch ' + str(epoch) + ' - train: ' + str(imgi) + '/' + str(len(tr_urls)))
            imgi += 1
            train_step(input_image, target)
            clear_output(wait=True)
        # Show progress on a few held-out images at the end of each epoch
        for inp, tar in test_dataset.take(5):
            generate_images(generator, inp, tar, str(imgi) + '_' + str(epoch), display_imgs=True)
        # Save a checkpoint every 25 epochs
        if (epoch + 1) % 25 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)
train(train_dataset, 100)
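# Minimal inference sketch (assumption: a trained checkpoint has been restored
# and 'some_blurry_pokemon.jpg' is a hypothetical file inside INPATH). It
# reuses the same preprocessing as training:
def deblur(filename):
    img = tf.cast(tf.image.decode_jpeg(tf.io.read_file(INPATH + '/' + filename)), tf.float32)[..., :3]
    img = tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])
    img = (img / 127.5) - 1                        # same normalization as training
    pred = generator(img[tf.newaxis, ...], training=True)
    return pred[0, ...] * 0.5 + 0.5                # back to [0, 1] for display

plt.imshow(deblur('some_blurry_pokemon.jpg'))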