example code from the official documentation
This commit is contained in:
parent 5a7702ec67
commit 1fd520d225
1 changed file with 243 additions and 0 deletions
src/ray-tune-keras-cifar10.py (new file)
@@ -0,0 +1,243 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Train a Keras CNN on the CIFAR10 small images dataset.

The model comes from https://zhuanlan.zhihu.com/p/29214791,
and it gets to about 87% validation accuracy in 100 epochs.

Note that the script requires a machine with 4 GPUs. You
can set {"gpu": 0} to use CPUs for training, although
it is less efficient.
"""

from __future__ import print_function

import argparse

import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.layers import (
    Convolution2D,
    Dense,
    Dropout,
    Flatten,
    Input,
    MaxPooling2D,
)
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from ray import train, tune
from ray.tune import Trainable
from ray.tune.schedulers import PopulationBasedTraining

num_classes = 10
NUM_SAMPLES = 128
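# NUM_SAMPLES caps how many CIFAR10 images each call to step() trains and
# evaluates on, which keeps a single Tune training_iteration fast; raise it
# (or drop the slicing in step()) for a full training run.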


class Cifar10Model(Trainable):
    def _read_data(self):
        # The data, split between train and test sets:
        (x_train, y_train), (x_test, y_test) = cifar10.load_data()

        # Convert class vectors to binary class matrices.
        y_train = tf.keras.utils.to_categorical(y_train, num_classes)
        y_test = tf.keras.utils.to_categorical(y_test, num_classes)

        x_train = x_train.astype("float32")
        x_train /= 255
        x_test = x_test.astype("float32")
        x_test /= 255

        return (x_train, y_train), (x_test, y_test)
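
    # The network below is a small VGG-style stack: three blocks of two 3x3
    # convolutions (64, 128, then 256 filters), each block followed by 2x2
    # max pooling, then dropout and a 10-way softmax classifier.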
    def _build_model(self, input_shape):
        x = Input(shape=(32, 32, 3))
        y = x
        y = Convolution2D(
            filters=64,
            kernel_size=3,
            strides=1,
            padding="same",
            activation="relu",
            kernel_initializer="he_normal",
        )(y)
        y = Convolution2D(
            filters=64,
            kernel_size=3,
            strides=1,
            padding="same",
            activation="relu",
            kernel_initializer="he_normal",
        )(y)
        y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y)

        y = Convolution2D(
            filters=128,
            kernel_size=3,
            strides=1,
            padding="same",
            activation="relu",
            kernel_initializer="he_normal",
        )(y)
        y = Convolution2D(
            filters=128,
            kernel_size=3,
            strides=1,
            padding="same",
            activation="relu",
            kernel_initializer="he_normal",
        )(y)
        y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y)

        y = Convolution2D(
            filters=256,
            kernel_size=3,
            strides=1,
            padding="same",
            activation="relu",
            kernel_initializer="he_normal",
        )(y)
        y = Convolution2D(
            filters=256,
            kernel_size=3,
            strides=1,
            padding="same",
            activation="relu",
            kernel_initializer="he_normal",
        )(y)
        y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y)

        y = Flatten()(y)
        y = Dropout(self.config.get("dropout", 0.5))(y)
        y = Dense(units=10, activation="softmax", kernel_initializer="he_normal")(y)

        model = Model(inputs=x, outputs=y, name="model1")
        return model
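
    # setup() runs once when a trial starts; every hyperparameter used below
    # ("lr", "decay", "dropout", and later "batch_size"/"epochs" in step())
    # comes from the Tune search space passed in via the trial config.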
    def setup(self, config):
        self.train_data, self.test_data = self._read_data()
        x_train = self.train_data[0]
        model = self._build_model(x_train.shape[1:])

        # "learning_rate" replaces the deprecated "lr" keyword in tf.keras.
        opt = tf.keras.optimizers.Adadelta(
            learning_rate=self.config.get("lr", 1e-4),
            weight_decay=self.config.get("decay", 1e-4),
        )
        model.compile(
            loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"]
        )
        self.model = model

    def step(self):
        x_train, y_train = self.train_data
        x_train, y_train = x_train[:NUM_SAMPLES], y_train[:NUM_SAMPLES]
        x_test, y_test = self.test_data
        x_test, y_test = x_test[:NUM_SAMPLES], y_test[:NUM_SAMPLES]

        aug_gen = ImageDataGenerator(
            # set input mean to 0 over the dataset
            featurewise_center=False,
            # set each sample mean to 0
            samplewise_center=False,
            # divide inputs by dataset std
            featurewise_std_normalization=False,
            # divide each input by its std
            samplewise_std_normalization=False,
            # apply ZCA whitening
            zca_whitening=False,
            # randomly rotate images in the range (degrees, 0 to 180)
            rotation_range=0,
            # randomly shift images horizontally (fraction of total width)
            width_shift_range=0.1,
            # randomly shift images vertically (fraction of total height)
            height_shift_range=0.1,
            # randomly flip images horizontally
            horizontal_flip=True,
            # do not flip images vertically
            vertical_flip=False,
        )

        aug_gen.fit(x_train)
        batch_size = self.config.get("batch_size", 64)
        gen = aug_gen.flow(x_train, y_train, batch_size=batch_size)
        # Model.fit accepts generators directly; fit_generator is deprecated.
        self.model.fit(
            gen, epochs=self.config.get("epochs", 1), validation_data=None
        )

        # Report only accuracy; evaluate() returns (loss, accuracy).
        _, accuracy = self.model.evaluate(x_test, y_test, verbose=0)
        return {"mean_accuracy": accuracy}
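
    # PBT relies on checkpoints to clone promising trials: save_checkpoint()
    # is invoked every checkpoint_frequency iterations (see CheckpointConfig
    # below) and load_checkpoint() restores those weights when a trial
    # exploits a better-performing one.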
    def save_checkpoint(self, checkpoint_dir):
        file_path = checkpoint_dir + "/model"
        self.model.save(file_path)

    def load_checkpoint(self, checkpoint_dir):
        # See https://stackoverflow.com/a/42763323
        del self.model
        file_path = checkpoint_dir + "/model"
        self.model = load_model(file_path)

    def cleanup(self):
        # If needed, save your model here before the actor exits, e.g.:
        # saved_path = self.model.save(self.logdir)
        # print("save model at: ", saved_path)
        pass


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    args, _ = parser.parse_known_args()

    space = {
        "epochs": 1,
        "batch_size": 64,
        "lr": tune.grid_search([10**-4, 10**-5]),
        "decay": tune.sample_from(lambda spec: spec.config.lr / 100.0),
        "dropout": tune.grid_search([0.25, 0.5]),
    }
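    # "lr" and "dropout" are grid-searched (2 x 2 = 4 combinations), and
    # "decay" is derived from whichever lr a trial gets; Tune repeats the
    # full grid for each of the num_samples requested in TuneConfig below.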
    if args.smoke_test:
        space["lr"] = 10**-4
        space["dropout"] = 0.5

    perturbation_interval = 10
    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        perturbation_interval=perturbation_interval,
        hyperparam_mutations={
            # The mutation callable is invoked with no arguments when PBT
            # resamples this hyperparameter.
            "dropout": lambda: np.random.uniform(0, 1),
        },
    )
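
    # Every perturbation_interval iterations PBT copies the weights and
    # hyperparameters of better-performing trials into worse ones, then
    # perturbs "dropout" (by default either resampling it from the uniform
    # distribution above or scaling the inherited value).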
    tuner = tune.Tuner(
        tune.with_resources(
            Cifar10Model,
            resources={"cpu": 1, "gpu": 1},
        ),
        run_config=train.RunConfig(
            name="pbt_cifar10",
            stop={
                "mean_accuracy": 0.80,
                "training_iteration": 30,
            },
            checkpoint_config=train.CheckpointConfig(
                checkpoint_frequency=perturbation_interval,
                checkpoint_score_attribute="mean_accuracy",
                num_to_keep=2,
            ),
        ),
        tune_config=tune.TuneConfig(
            scheduler=pbt,
            num_samples=4,
            metric="mean_accuracy",
            mode="max",
            reuse_actors=True,
        ),
        param_space=space,
    )
    results = tuner.fit()
    print("Best hyperparameters found were: ", results.get_best_result().config)
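
# A quick way to exercise the script (a sketch; it assumes Ray Tune and
# TensorFlow are installed, and either a GPU is available or the "gpu" entry
# in tune.with_resources above has been set to 0, as the module docstring
# suggests):
#
#     python src/ray-tune-keras-cifar10.py --smoke-test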