runner.Task
December 14, 2023
Defines a learning objective for a GNN.
A Task represents a learning objective for a GNN model and defines all the
non-GNN pieces around the base GNN. Specifically:
1) preprocess is expected to return a GraphTensor (or GraphTensors) and a Field, where (a) the base GNN's output for each GraphTensor is passed to predict and (b) the Field is used as the training label (for supervised tasks);
2) predict is expected to (a) take the base GNN's output for each GraphTensor returned by preprocess and (b) return a tensor with the model's prediction for this task;
3) losses is expected to return callables (tf.Tensor, tf.Tensor) -> tf.Tensor that accept (y_true, y_pred), where y_true is produced by some dataset and y_pred is the model's prediction from (2);
4) metrics is expected to return callables (tf.Tensor, tf.Tensor) -> tf.Tensor that accept (y_true, y_pred), where y_true is produced by some dataset and y_pred is the model's prediction from (2).
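The four responsibilities listed above can be illustrated with a framework-free sketch. It assumes plain Python dicts and floats stand in for GraphTensor, Field, and tf.Tensor, and that a mean over node features stands in for the base GNN. ToyTask, ToyGraphRegression, and toy_gnn are hypothetical names, not part of the runner API, and no TensorFlow is involved.

```python
import abc
from typing import Callable, Dict, List, Sequence, Tuple

# Illustration only: hypothetical stand-ins, not the real runner.Task API.
Graph = Dict[str, List[float]]        # stands in for a GraphTensor
LossFn = Callable[[float, float], float]  # (y_true, y_pred) -> scalar


class ToyTask(abc.ABC):
    """Mirrors the four abstract methods of runner.Task."""

    @abc.abstractmethod
    def preprocess(self, inputs: Graph) -> Tuple[Graph, float]: ...

    @abc.abstractmethod
    def predict(self, *args: float) -> float: ...

    @abc.abstractmethod
    def losses(self) -> Sequence[LossFn]: ...

    @abc.abstractmethod
    def metrics(self) -> Sequence[LossFn]: ...


class ToyGraphRegression(ToyTask):
    def preprocess(self, inputs: Graph) -> Tuple[Graph, float]:
        # (1) Split the label out so the base GNN never sees it.
        features = {k: v for k, v in inputs.items() if k != "label"}
        return features, inputs["label"][0]

    def predict(self, *args: float) -> float:
        # (2) One base GNN output per graph returned by preprocess (here one).
        (gnn_output,) = args
        return 2.0 * gnn_output  # a fixed "head" without learnable weights

    def losses(self) -> Sequence[LossFn]:
        # (3) Callables that accept (y_true, y_pred).
        return [lambda y_true, y_pred: (y_true - y_pred) ** 2]

    def metrics(self) -> Sequence[LossFn]:
        # (4) Same signature as the losses.
        return [lambda y_true, y_pred: abs(y_true - y_pred)]


def toy_gnn(graph: Graph) -> float:
    """Hypothetical base GNN: the mean of the node features."""
    feats = graph["node_feature"]
    return sum(feats) / len(feats)


task = ToyGraphRegression()
graph, y_true = task.preprocess({"node_feature": [1.0, 2.0, 3.0], "label": [5.0]})
y_pred = task.predict(toy_gnn(graph))
loss = task.losses()[0](y_true, y_pred)
mae = task.metrics()[0](y_true, y_pred)
print(y_pred, loss, mae)  # 4.0 1.0 1.0
```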
A Task can emit multiple outputs in predict. In that case, we require that (a)
the output is a mapping, (b) the outputs of losses and metrics are also mappings
with matching keys, and (c) there is exactly one loss per key (there may be a
sequence of metrics per key). This prevents losses from being accidentally
dropped (see b/291874188).
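The three structural rules for multi-output tasks can be expressed as a small checker. This is a minimal sketch, assuming "matching keys" means identical key sets and that plain callables stand in for Keras losses and metrics; check_multi_output is a hypothetical helper, and the actual runner code may enforce these rules differently.

```python
from collections.abc import Mapping, Sequence
from typing import Any


def check_multi_output(predictions: Any, losses: Any, metrics: Any) -> None:
    """Hypothetical checker for the multi-output rules (illustration only)."""
    # (a) Multiple outputs must be a mapping.
    if not isinstance(predictions, Mapping):
        raise ValueError("multiple outputs must be a mapping")
    # (b) Losses and metrics must be mappings keyed like the predictions.
    for name, value in (("losses", losses), ("metrics", metrics)):
        if not isinstance(value, Mapping):
            raise ValueError(f"{name} must be a mapping")
        if set(value) != set(predictions):
            raise ValueError(f"{name} keys {sorted(value)} != {sorted(predictions)}")
    # (c) Exactly one loss per key: a sequence of losses is rejected.
    for key, loss in losses.items():
        if isinstance(loss, Sequence) or not callable(loss):
            raise ValueError(f"expected exactly one loss for {key!r}")
    # Metrics may be one callable or a sequence of callables per key.
    for key, metric in metrics.items():
        items = metric if isinstance(metric, Sequence) else [metric]
        if not all(callable(m) for m in items):
            raise ValueError(f"metrics for {key!r} must be callables")


mse = lambda y_true, y_pred: (y_true - y_pred) ** 2
mae = lambda y_true, y_pred: abs(y_true - y_pred)
preds = {"price": 1.0, "rating": 0.5}

# Passes: one loss per key, a sequence of metrics for "price".
check_multi_output(preds, {"price": mse, "rating": mse}, {"price": [mse, mae], "rating": mae})

try:
    # Two losses for one key are rejected instead of silently keeping one.
    check_multi_output(preds, {"price": [mse, mae], "rating": mse}, {"price": mae, "rating": mae})
except ValueError as err:
    print(err)  # expected exactly one loss for 'price'
```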
No constraints are placed on the predict method: for example, it may append a
head with learnable weights, or it may perform tensor computations only. The
Task as a whole coordinates what that means for the dataset (via preprocess),
for modeling (via predict), and for optimization (via losses).
Tasks are applied in the scope of a training invocation: they are subject to
the executing context of the Trainer and should override it when needed
(e.g., a global policy such as tf.keras.mixed_precision.global_policy() and its
implications for logit and activation layers).
Methods
losses
@abc.abstractmethod
losses() -> Losses
Returns arbitrary task specific losses.
metrics
@abc.abstractmethod
metrics() -> Metrics
Returns arbitrary task specific metrics.
predict
@abc.abstractmethod
predict(
    *args
) -> Predictions
Produces prediction outputs for the learning objective.
Overall model composition* uses the Keras Functional API
(https://www.tensorflow.org/guide/keras/functional) to map symbolic Keras
GraphTensor inputs to symbolic Keras Field outputs. Outputs must match the
structure (one or mapping) of the labels from preprocess.
*) outputs = predict(GNN(inputs)), where inputs are the GraphTensors
returned by preprocess(...), GNN is the base GNN, predict is this method,
and outputs are the prediction outputs for the learning objective.
| Args | |
|---|---|
| *args | The symbolic Keras GraphTensor input(s). These inputs correspond (in sequence) to the base GNN output of each GraphTensor returned by preprocess(...). |
| Returns | |
|---|---|
| The model's prediction output for this task. |
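The composition outputs = predict(GNN(inputs)) can be sketched without Keras, assuming plain Python callables stand in for preprocess, the base GNN, and predict, and dicts stand in for GraphTensors. The names compose, gnn, and predict in this sketch are hypothetical. It shows that the same base GNN runs on every graph that preprocess returns, and that predict receives those outputs positionally, in the same order.

```python
from typing import Any, Callable, Tuple


def compose(
    preprocess: Callable[[Any], Tuple[Any, Any]],
    gnn: Callable[[Any], Any],
    predict: Callable[..., Any],
    inputs: Any,
) -> Tuple[Any, Any]:
    """Returns (outputs, labels) with outputs = predict(*[gnn(g) for g in graphs])."""
    graphs, labels = preprocess(inputs)
    # preprocess may return one graph or a non-empty sequence of graphs.
    # Assumes a single graph is never itself a list or tuple (graphs here are dicts).
    if not isinstance(graphs, (list, tuple)):
        graphs = [graphs]
    # The same base GNN (shared weights in a real model) runs on every graph.
    gnn_outputs = [gnn(g) for g in graphs]
    return predict(*gnn_outputs), labels


# Two "views" of a graph, a GNN that sums node features, and a predict
# step that averages the two GNN outputs.
outputs, labels = compose(
    preprocess=lambda x: ([{"h": x["h"]}, {"h": [2 * v for v in x["h"]]}], x["label"]),
    gnn=lambda g: sum(g["h"]),
    predict=lambda a, b: (a + b) / 2,
    inputs={"h": [1.0, 2.0], "label": 4.0},
)
print(outputs, labels)  # 4.5 4.0
```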
preprocess
@abc.abstractmethod
preprocess(
    inputs: GraphTensor
) -> tuple[OneOrSequenceOf[GraphTensor], OneOrMappingOf[Field]]
Preprocesses a scalar (after merge_batch_to_components) GraphTensor.
This function uses the Keras functional API to define non-trainable
transformations of the symbolic input GraphTensor, which get executed during
dataset preprocessing in a tf.data.Dataset.map(...) operation. It has two
responsibilities:
- Splitting the training label out of the input for training. It must be returned as a separate tensor or mapping of tensors.
- Optionally, transforming input features. Some advanced modeling techniques require running the same base GNN on multiple different transformations, so this function may return a single GraphTensor or a non-empty sequence of GraphTensors. The corresponding base GNN output for each GraphTensor is provided to the predict(...) method.
| Args | |
|---|---|
| inputs | A symbolic Keras GraphTensor for processing. |
| Returns | |
|---|---|
| A tuple of the processed GraphTensor(s) and one Field or a mapping of Fields to be used as labels. |
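The two responsibilities of preprocess can be sketched in plain Python, assuming a dict with "context" and "nodes" feature maps stands in for a scalar GraphTensor. The real method builds non-trainable Keras functional transformations instead. The masking transform in this sketch is a hypothetical example of an input transformation, chosen only to show how one call can return several graphs alongside the label.

```python
import copy
from typing import Dict, List, Tuple

# Hypothetical stand-in for a scalar GraphTensor: context and node features.
Graph = Dict[str, Dict[str, List[float]]]


def preprocess(inputs: Graph, label_feature: str = "label") -> Tuple[List[Graph], float]:
    """Splits the label out and returns two views of the graph (illustration only)."""
    graph = copy.deepcopy(inputs)  # never mutate the caller's data
    # Responsibility 1: remove the label so the base GNN cannot read it.
    label = graph["context"].pop(label_feature)[0]
    # Responsibility 2 (optional): a second, transformed view that zeroes
    # every other node feature, standing in for a feature-masking transform.
    masked = copy.deepcopy(graph)
    masked["nodes"]["h"] = [0.0 if i % 2 else v for i, v in enumerate(masked["nodes"]["h"])]
    return [graph, masked], label


views, label = preprocess({"context": {"label": [1.0]}, "nodes": {"h": [0.5, 1.5, 2.5]}})
print(label, [v["nodes"]["h"] for v in views])
# 1.0 [[0.5, 1.5, 2.5], [0.5, 0.0, 2.5]]
```

Each returned view would then be passed through the same base GNN, and the resulting outputs handed to predict(...) in order.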