diff --git a/src/ray-tune-keras-cifar10.py b/src/ray-tune-keras-cifar10.py
new file mode 100644
index 0000000..c1a0cc4
--- /dev/null
+++ b/src/ray-tune-keras-cifar10.py
@@ -0,0 +1,243 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""Train a Keras CNN on the CIFAR10 small images dataset.
+
+The model comes from: https://zhuanlan.zhihu.com/p/29214791,
+and it gets to about 87% validation accuracy in 100 epochs.
+
+Note that the script requires a machine with 4 GPUs. You
+can set {"gpu": 0} to use CPUs for training, although
+it is less efficient.
+"""
+
+from __future__ import print_function
+
+import argparse
+import os
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras.datasets import cifar10
+from tensorflow.keras.layers import (
+    Convolution2D,
+    Dense,
+    Dropout,
+    Flatten,
+    Input,
+    MaxPooling2D,
+)
+from tensorflow.keras.models import Model, load_model
+from tensorflow.keras.preprocessing.image import ImageDataGenerator
+
+from ray import train, tune
+from ray.tune import Trainable
+from ray.tune.schedulers import PopulationBasedTraining
+
+num_classes = 10
+# Each step() trains and evaluates on a small slice so iterations stay fast.
+NUM_SAMPLES = 128
+
+
+class Cifar10Model(Trainable):
+    def _read_data(self):
+        # The data, split between train and test sets:
+        (x_train, y_train), (x_test, y_test) = cifar10.load_data()
+
+        # Convert class vectors to binary class matrices.
+        y_train = tf.keras.utils.to_categorical(y_train, num_classes)
+        y_test = tf.keras.utils.to_categorical(y_test, num_classes)
+
+        x_train = x_train.astype("float32")
+        x_train /= 255
+        x_test = x_test.astype("float32")
+        x_test /= 255
+
+        return (x_train, y_train), (x_test, y_test)
+
+    def _build_model(self, input_shape):
+        x = Input(shape=input_shape)
+        y = x
+        y = Convolution2D(
+            filters=64,
+            kernel_size=3,
+            strides=1,
+            padding="same",
+            activation="relu",
+            kernel_initializer="he_normal",
+        )(y)
+        y = Convolution2D(
+            filters=64,
+            kernel_size=3,
+            strides=1,
+            padding="same",
+            activation="relu",
+            kernel_initializer="he_normal",
+        )(y)
+        y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y)
+
+        y = Convolution2D(
+            filters=128,
+            kernel_size=3,
+            strides=1,
+            padding="same",
+            activation="relu",
+            kernel_initializer="he_normal",
+        )(y)
+        y = Convolution2D(
+            filters=128,
+            kernel_size=3,
+            strides=1,
+            padding="same",
+            activation="relu",
+            kernel_initializer="he_normal",
+        )(y)
+        y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y)
+
+        y = Convolution2D(
+            filters=256,
+            kernel_size=3,
+            strides=1,
+            padding="same",
+            activation="relu",
+            kernel_initializer="he_normal",
+        )(y)
+        y = Convolution2D(
+            filters=256,
+            kernel_size=3,
+            strides=1,
+            padding="same",
+            activation="relu",
+            kernel_initializer="he_normal",
+        )(y)
+        y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y)
+
+        y = Flatten()(y)
+        y = Dropout(self.config.get("dropout", 0.5))(y)
+        y = Dense(units=10, activation="softmax", kernel_initializer="he_normal")(y)
+
+        model = Model(inputs=x, outputs=y, name="model1")
+        return model
+
+    def setup(self, config):
+        self.train_data, self.test_data = self._read_data()
+        x_train = self.train_data[0]
+        model = self._build_model(x_train.shape[1:])
+
+        # `lr` and `decay` come from the PBT search space defined below.
+        opt = tf.keras.optimizers.Adadelta(
+            learning_rate=self.config.get("lr", 1e-4),
+            weight_decay=self.config.get("decay", 1e-4),
+        )
+        model.compile(
+            loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"]
+        )
+        self.model = model
+
+    def step(self):
+        x_train, y_train = self.train_data
+        x_train, y_train = x_train[:NUM_SAMPLES], y_train[:NUM_SAMPLES]
+        x_test, y_test = self.test_data
+        x_test, y_test = x_test[:NUM_SAMPLES], y_test[:NUM_SAMPLES]
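+
+        # One step() call = `epochs` passes over the subsampled, augmented
+        # training data, followed by an evaluation on the test slice; the
+        # returned "mean_accuracy" is the signal PBT exploits/explores on.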
+
+        aug_gen = ImageDataGenerator(
+            # set input mean to 0 over the dataset
+            featurewise_center=False,
+            # set each sample mean to 0
+            samplewise_center=False,
+            # divide inputs by dataset std
+            featurewise_std_normalization=False,
+            # divide each input by its std
+            samplewise_std_normalization=False,
+            # apply ZCA whitening
+            zca_whitening=False,
+            # randomly rotate images in the range (degrees, 0 to 180)
+            rotation_range=0,
+            # randomly shift images horizontally (fraction of total width)
+            width_shift_range=0.1,
+            # randomly shift images vertically (fraction of total height)
+            height_shift_range=0.1,
+            # randomly flip images horizontally
+            horizontal_flip=True,
+            # do not flip images vertically
+            vertical_flip=False,
+        )
+
+        aug_gen.fit(x_train)
+        batch_size = self.config.get("batch_size", 64)
+        gen = aug_gen.flow(x_train, y_train, batch_size=batch_size)
+        # Model.fit accepts generators in TF 2.x; fit_generator is deprecated.
+        self.model.fit(gen, epochs=self.config.get("epochs", 1))
+
+        # loss, accuracy
+        _, accuracy = self.model.evaluate(x_test, y_test, verbose=0)
+        return {"mean_accuracy": accuracy}
+
+    def save_checkpoint(self, checkpoint_dir):
+        file_path = os.path.join(checkpoint_dir, "model")
+        self.model.save(file_path)
+
+    def load_checkpoint(self, checkpoint_dir):
+        # See https://stackoverflow.com/a/42763323
+        del self.model
+        file_path = os.path.join(checkpoint_dir, "model")
+        self.model = load_model(file_path)
+
+    def cleanup(self):
+        # If needed, save your model on exit, e.g.:
+        # saved_path = self.model.save(self.logdir)
+        # print("save model at: ", saved_path)
+        pass
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--smoke-test", action="store_true", help="Finish quickly for testing"
+    )
+    args, _ = parser.parse_known_args()
+
+    space = {
+        "epochs": 1,
+        "batch_size": 64,
+        "lr": tune.grid_search([10**-4, 10**-5]),
+        "decay": tune.sample_from(lambda spec: spec.config.lr / 100.0),
+        "dropout": tune.grid_search([0.25, 0.5]),
+    }
+    if args.smoke_test:
+        space["lr"] = 10**-4
+        space["dropout"] = 0.5
+
+    perturbation_interval = 10
+    pbt = PopulationBasedTraining(
+        time_attr="training_iteration",
+        perturbation_interval=perturbation_interval,
+        hyperparam_mutations={
+            "dropout": lambda _: np.random.uniform(0, 1),
+        },
+    )
+
+    tuner = tune.Tuner(
+        tune.with_resources(
+            Cifar10Model,
+            resources={"cpu": 1, "gpu": 1},
+        ),
+        run_config=train.RunConfig(
+            name="pbt_cifar10",
+            stop={
+                "mean_accuracy": 0.80,
+                "training_iteration": 30,
+            },
+            checkpoint_config=train.CheckpointConfig(
+                checkpoint_frequency=perturbation_interval,
+                checkpoint_score_attribute="mean_accuracy",
+                num_to_keep=2,
+            ),
+        ),
+        tune_config=tune.TuneConfig(
+            scheduler=pbt,
+            num_samples=4,
+            metric="mean_accuracy",
+            mode="max",
+            reuse_actors=True,
+        ),
+        param_space=space,
+    )
+    results = tuner.fit()
+    print("Best hyperparameters found were: ", results.get_best_result().config)
\ No newline at end of file
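
A minimal follow-up sketch, not part of the patch above, of how the results object returned by tuner.fit() could be inspected and the best checkpoint reloaded after the run. It assumes Ray 2.x's ResultGrid/Checkpoint APIs; the "model" filename matches what save_checkpoint() writes in the script:

    import os
    from tensorflow.keras.models import load_model

    # Pick the trial with the best reported mean_accuracy (metric/mode are
    # already set in TuneConfig, so the arguments are optional here).
    best = results.get_best_result(metric="mean_accuracy", mode="max")
    print("best mean_accuracy:", best.metrics["mean_accuracy"])

    # Materialize the checkpoint locally and reload the saved Keras model.
    with best.checkpoint.as_directory() as checkpoint_dir:
        model = load_model(os.path.join(checkpoint_dir, "model"))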