runner.Trainer

December 14, 2023

A class for training and validation of a Keras model.

Attributes
model_dir
strategy
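For orientation, the documented surface of Trainer can be sketched as a Python abstract base class. This is a minimal, framework-free sketch, not the TF-GNN source. The return types of model_dir (str) and strategy (in TF-GNN, presumably a tf.distribute.Strategy) are assumptions, as are the keyword-only epochs=1 default and the None default for valid_ds_provider.

```python
import abc
from typing import Any, Callable, Optional


class Trainer(abc.ABC):
    """Sketch of the documented Trainer surface (not the library's code)."""

    @property
    @abc.abstractmethod
    def model_dir(self) -> str:
        """Directory associated with this trainer (str type is an assumption)."""

    @property
    @abc.abstractmethod
    def strategy(self) -> Any:
        """Distribution strategy used for training (type left open here)."""

    @abc.abstractmethod
    def train(
        self,
        model_fn: Callable[[], Any],
        train_ds_provider: Any,
        *,  # keyword-only epochs/valid_ds_provider are an assumption
        epochs: int = 1,
        valid_ds_provider: Optional[Any] = None,
    ) -> Any:
        """Trains the model returned by model_fn and returns the trained model."""
```

A concrete trainer subclasses it and implements all three members; instantiating Trainer directly raises TypeError because its abstract members are unimplemented.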

Methods

train

Trains a tf.keras.Model with optional validation.

Args
model_fn: A callable that returns a tf.keras.Model for use in training and validation.
train_ds_provider: A DatasetProvider for training. The items of its tf.data.Dataset are pairs (graph_tensor, label) that represent one batch of per-replica training inputs after GraphTensor.merge_batch_to_components() has been applied.
epochs: The number of epochs to train.
valid_ds_provider: An optional DatasetProvider for validation. The items of its tf.data.Dataset are pairs (graph_tensor, label) that represent one batch of per-replica validation inputs after GraphTensor.merge_batch_to_components() has been applied.
Returns
A trained tf.keras.Model.
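One plausible reading of the train contract can be illustrated with a small, framework-free sketch: build the model once with model_fn, make epochs passes over the training batches, evaluate after each pass only when a validation provider is given, and return the trained model. Everything here is hypothetical: ListDatasetProvider, MeanModel, ToyTrainer, fit_batch, and evaluate are illustrative stand-ins rather than TF-GNN APIs, plain lists replace tf.data.Dataset, and (features, label) pairs stand in for (graph_tensor, label). Evaluating once per epoch is also an assumption. A real implementation would typically build the model under strategy.scope() and distribute the datasets across replicas, which this sketch omits.

```python
from typing import Callable, List, Optional, Tuple

# (features, labels); a hypothetical stand-in for a (graph_tensor, label) batch.
Batch = Tuple[List[float], List[float]]


class ListDatasetProvider:
    """Hypothetical provider that returns the same in-memory batches on each call."""

    def __init__(self, batches: List[Batch]):
        self._batches = batches

    def get_dataset(self, context=None) -> List[Batch]:
        # `context` is unused; it stands in for a per-replica input context (assumption).
        return list(self._batches)


class MeanModel:
    """Hypothetical model that predicts the mean of all training labels seen so far."""

    def __init__(self):
        self.total = 0.0
        self.count = 0
        self.val_losses: List[float] = []

    def fit_batch(self, batch: Batch) -> None:
        _, labels = batch
        self.total += sum(labels)
        self.count += len(labels)

    def evaluate(self, batches: List[Batch]) -> float:
        # Mean squared error of the constant prediction over all validation labels.
        pred = self.total / self.count if self.count else 0.0
        errors = [(y - pred) ** 2 for _, labels in batches for y in labels]
        return sum(errors) / len(errors) if errors else 0.0


class ToyTrainer:
    """Hypothetical trainer with the same member names as Trainer."""

    def __init__(self, model_dir: str):
        self._model_dir = model_dir

    @property
    def model_dir(self) -> str:
        return self._model_dir

    @property
    def strategy(self):
        return None  # no distribution in this sketch

    def train(self, model_fn: Callable[[], MeanModel],
              train_ds_provider: ListDatasetProvider, *, epochs: int = 1,
              valid_ds_provider: Optional[ListDatasetProvider] = None) -> MeanModel:
        model = model_fn()  # build the model once
        for _ in range(epochs):
            for batch in train_ds_provider.get_dataset():
                model.fit_batch(batch)
            if valid_ds_provider is not None:  # validation is optional
                model.val_losses.append(model.evaluate(valid_ds_provider.get_dataset()))
        return model


train_provider = ListDatasetProvider([([0.0], [1.0]), ([0.0], [3.0])])
valid_provider = ListDatasetProvider([([0.0], [4.0])])
trainer = ToyTrainer("/tmp/toy_model")
model = trainer.train(MeanModel, train_provider, epochs=2, valid_ds_provider=valid_provider)
print(model.val_losses)  # [4.0, 4.0]
```

Passing the class MeanModel as model_fn works because any zero-argument callable that returns a model fits the described contract; omitting valid_ds_provider simply skips the per-epoch evaluation.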