runner.RootNodeMulticlassClassification

December 14, 2023

Root node multiclass classification.

Inherits From: Task

Args
node_set_name The node set containing the root node.
num_classes The number of classes. Exactly one of num_classes or class_names must be specified.
class_names The class names. Exactly one of num_classes or class_names must be specified.
per_class_statistics Whether to compute statistics per class.
state_name The feature name for activations (e.g., tfgnn.HIDDEN_STATE).
name The classification head's layer name. To control the naming of saved model outputs, see the runner model exporters (e.g., KerasModelExporter).
label_fn A label extraction function. This function mutates the input GraphTensor. Mutually exclusive with label_feature_name.
label_feature_name A label feature name for readout from the auxiliary '_readout' node set. Readout does not mutate the input GraphTensor. Mutually exclusive with label_fn.
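
The argument constraints above can be illustrated with a minimal sketch in plain Python. The helpers `resolve_num_classes` and `check_label_source` are hypothetical names used only for illustration; they are not part of the runner API, and the library's own validation may differ in detail.

```python
from typing import Callable, Optional, Sequence


def resolve_num_classes(
    num_classes: Optional[int] = None,
    class_names: Optional[Sequence[str]] = None,
) -> int:
    """Hypothetical helper: enforces 'exactly one of num_classes or class_names'."""
    # Both None or both set violates the documented constraint.
    if (num_classes is None) == (class_names is None):
        raise ValueError("Specify exactly one of `num_classes` or `class_names`.")
    return num_classes if num_classes is not None else len(class_names)


def check_label_source(
    label_fn: Optional[Callable] = None,
    label_feature_name: Optional[str] = None,
) -> None:
    """Hypothetical helper: rejects passing both label sources at once."""
    if label_fn is not None and label_feature_name is not None:
        raise ValueError("`label_fn` and `label_feature_name` are mutually exclusive.")
```

In this sketch, passing class_names implies the number of classes from the length of the list, which is an assumption about how the two arguments relate.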

Methods

gather_activations

Gather activations from root nodes.
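
As a rough illustration of what gathering root activations involves, the sketch below works on plain Python lists rather than a GraphTensor. It assumes the root is the first node of each graph component, which is the usual convention for rooted subgraphs but is not stated on this page; `gather_root_states` is a hypothetical name.

```python
from typing import List, Sequence


def gather_root_states(
    node_states: Sequence[Sequence[float]],
    component_sizes: Sequence[int],
) -> List[Sequence[float]]:
    """Picks the first node's state from each graph component (assumed root)."""
    roots = []
    offset = 0  # Index of the first node of the current component.
    for size in component_sizes:
        if size < 1:
            raise ValueError("Every component needs at least one node.")
        roots.append(node_states[offset])
        offset += size
    return roots
```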

losses

Sparse categorical crossentropy loss.
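
For reference, sparse categorical crossentropy for one example with an integer label can be written out in plain Python as below. The sketch assumes the loss is computed directly on logits, since predict returns logits; it is an illustration of the formula, not the library's implementation.

```python
import math
from typing import Sequence


def sparse_categorical_crossentropy(logits: Sequence[float], label: int) -> float:
    """Cross-entropy of one example: -log(softmax(logits)[label])."""
    max_logit = max(logits)  # Subtract the max for numerical stability.
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    return log_sum_exp - logits[label]
```

"Sparse" here means the label is a single class index rather than a one-hot vector.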

metrics

Sparse categorical metrics.
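
This page does not list the individual metrics. As one illustrative example of a sparse categorical metric, accuracy over integer labels can be sketched in plain Python as follows; `sparse_categorical_accuracy` here is a standalone function written for this example.

```python
from typing import Sequence


def sparse_categorical_accuracy(
    labels: Sequence[int], logits: Sequence[Sequence[float]]
) -> float:
    """Fraction of examples whose highest-scoring class equals the integer label."""
    correct = 0
    for label, row in zip(labels, logits):
        # Index of the largest logit is the predicted class.
        predicted = max(range(len(row)), key=row.__getitem__)
        correct += int(predicted == label)
    return correct / len(labels)
```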

predict

Apply a linear head for classification.

Args
inputs A tfgnn.GraphTensor for classification.
Returns
The classification logits.
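
A linear head maps the root node's activation vector to one logit per class. The sketch below shows that computation on plain Python lists; the weight layout (one row per class) and the presence of a bias term are assumptions made for illustration.

```python
from typing import List, Sequence


def linear_head(
    activations: Sequence[float],
    weights: Sequence[Sequence[float]],
    bias: Sequence[float],
) -> List[float]:
    """Returns one logit per class: dot(weights[c], activations) + bias[c]."""
    return [
        sum(w * a for w, a in zip(row, activations)) + b
        for row, b in zip(weights, bias)
    ]
```

The logits are unnormalized scores; applying a softmax would turn them into class probabilities.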

preprocess

Preprocesses a scalar (after merge_batch_to_components) GraphTensor.

This function uses the Keras functional API to define non-trainable transformations of the symbolic input GraphTensor, which get executed during dataset preprocessing in a tf.data.Dataset.map(...) operation. It has two responsibilities:

  1. Splitting the training label out of the input for training. It must be returned as a separate tensor or mapping of tensors.
  2. Optionally, transforming input features. Some advanced modeling techniques require running the same base GNN on multiple different transformations, so this function may return a single GraphTensor or a non-empty sequence of GraphTensors. The corresponding base GNN output for each GraphTensor is provided to the predict(...) method.
Args
inputs A symbolic Keras GraphTensor for processing.
Returns
A tuple of the processed GraphTensor(s) and either a single Field or a mapping of Fields to be used as labels.
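
The first responsibility, splitting the label out of the input, can be pictured with a minimal sketch in which a plain dict stands in for the features of a GraphTensor. The function `split_label` and the feature keys are hypothetical and only mirror the shape of the return value described above.

```python
from typing import Any, Dict, Tuple


def split_label(
    features: Dict[str, Any], label_key: str
) -> Tuple[Dict[str, Any], Any]:
    """Returns (features without the label, label); the input dict is left untouched."""
    remaining = {k: v for k, v in features.items() if k != label_key}
    return remaining, features[label_key]
```

Returning the label separately keeps it out of the features that the model sees during training.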