runner.Trainer
December 14, 2023
A class for training and validation of a Keras model.
| Attributes | |
|---|---|
| `model_dir` | |
| `strategy` | |
Methods
train
```python
@abc.abstractmethod
train(
    model_fn: Callable[[], tf.keras.Model],
    train_ds_provider: DatasetProvider,
    *,
    epochs: int = 1,
    valid_ds_provider: Optional[DatasetProvider] = None
) -> tf.keras.Model
```
Trains a tf.keras.Model with optional validation.
| Args | |
|---|---|
| `model_fn` | A callable with no arguments that returns a `tf.keras.Model` for use in training and validation. |
| `train_ds_provider` | A `DatasetProvider` for training. The items of its `tf.data.Dataset` are pairs `(graph_tensor, label)` that represent one batch of per-replica training inputs after `GraphTensor.merge_batch_to_components()` has been applied. |
| `epochs` | The number of epochs to train. Defaults to 1. |
| `valid_ds_provider` | An optional `DatasetProvider` for validation. The items of its `tf.data.Dataset` are pairs `(graph_tensor, label)` that represent one batch of per-replica validation inputs after `GraphTensor.merge_batch_to_components()` has been applied. Defaults to `None`. |
| Returns |
|---|
| A trained `tf.keras.Model`. |