tfgnn.Context
December 14, 2023
A composite tensor for graph context features.
tfgnn.Context(
data: Data, spec: 'GraphPieceSpecBase'
)
The items of a Context are the graph components, just as the items of a node
set are the nodes and the items of an edge set are the edges. A Context stores
features that belong to a graph component as a whole, not to any particular
node or edge. Each context feature has shape [*graph_shape, num_components, ...],
where num_components is the number of graph components in a graph (which may
be ragged).
| Args | |
|---|---|
| data | Nest of Field or subclasses of GraphPieceBase. |
| spec | An instance of a GraphPieceSpecBase subclass whose _data_spec matches data. |
Methods
from_fields
@classmethod
from_fields(
*_,
features: Optional[Fields] = None,
sizes: Optional[Field] = None,
shape: Optional[ShapeLike] = None,
indices_dtype: Optional[tf.dtypes.DType] = None,
validate: Optional[bool] = None
) -> 'Context'
Constructs a new instance from context fields.
Example:
tfgnn.Context.from_fields(features={'country_code': ['CH']})
| Args | |
|---|---|
| features | A mapping from feature name to feature Tensor or RaggedTensor. All feature tensors must have shape [*graph_shape, num_components, *feature_shape], where num_components is the number of graph components (which may be ragged) and feature_shape holds the feature-specific dimensions (which may be ragged). |
| sizes | A Tensor of 1's with shape [*graph_shape, num_components], where num_components is the number of graph components (which may be ragged). For symmetry with sizes in NodeSet and EdgeSet, this counts the items per graph component, but since the items of a Context are the components themselves, each value is 1. Must be compatible with shape, if that is specified. |
| shape | The shape of this tensor and of a GraphTensor containing it, also known as the graph_shape. If not specified, the shape is inferred from sizes, or set to [] if sizes is not specified either. |
| indices_dtype | The indices_dtype of a GraphTensor containing this object, used as row_splits_dtype when batching potentially ragged fields. If sizes are specified, they are cast to that type. |
| validate | If True, use tf.assert ops to inspect the shapes of each field and check at runtime that they form a valid Context. The default behavior is set by disable_graph_tensor_validation_at_runtime() and enable_graph_tensor_validation_at_runtime(). |
| Returns |
|---|
| A Context composite tensor. |
get_features_dict
get_features_dict() -> Dict[FieldName, Field]
Returns a copy of the features as a dictionary.
replace_features
replace_features(
features: Fields
) -> 'Context'
Returns a new instance with a new set of features.
set_shape
set_shape(
new_shape: ShapeLike
) -> 'GraphPieceBase'
Deprecated. Use with_shape().
with_indices_dtype
with_indices_dtype(
dtype: tf.dtypes.DType
) -> 'GraphPieceBase'
Returns a copy of this piece with the given indices dtype.
with_row_splits_dtype
with_row_splits_dtype(
dtype: tf.dtypes.DType
) -> 'GraphPieceBase'
Returns a copy of this piece with the given row splits dtype.
with_shape
with_shape(
new_shape: ShapeLike
) -> 'GraphPieceBase'
Enforces the common prefix shape on all the contained features.
__getitem__
__getitem__(
feature_name: FieldName
) -> Field
Indexing operator [] to access feature values by their name.