tfgnn.NodeSet
December 14, 2023
A composite tensor for node set features plus size information.
tfgnn.NodeSet(
data: Data, spec: 'GraphPieceSpecBase'
)
The items of a node set are a subset of the graph's nodes.
All nodes in a node set have the same features, each identified by a string key.
Each feature is stored as one tensor of shape [*graph_shape, num_nodes, *feature_shape], where num_nodes is the number of nodes in a graph (could be
ragged) and feature_shape is the shape of the feature value for each node.
NodeSet supports both fixed-size and variable-size features. The fixed-size
features must have fully defined feature_shape. They are stored as tf.Tensor
if num_nodes is fixed-size or graph_shape.rank = 0. Variable-size node
features are always stored as tf.RaggedTensor.
Note that node set features are indexed without regard to graph components.
Which node belongs to which graph component is recorded in the .sizes tensor,
which gives the number of nodes in each graph component.
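The relationship between the flat node index and .sizes can be illustrated without TensorFlow. The following is a minimal sketch in plain Python, not TF-GNN code: the helper names component_ids and component_slices are hypothetical, and the sizes and year values are made up for illustration. It assumes a scalar graph shape, so sizes is a flat list with one entry per component.

```python
from itertools import accumulate


def component_ids(sizes):
    """For each node index, return the index of the component that owns it."""
    ids = []
    for component, count in enumerate(sizes):
        ids.extend([component] * count)
    return ids


def component_slices(sizes):
    """Return (start, end) node index ranges, one per component."""
    ends = list(accumulate(sizes))
    starts = [0] + ends[:-1]
    return list(zip(starts, ends))


# Hypothetical node set: two components holding 2 and 3 nodes.
sizes = [2, 3]
# A feature is indexed by node only; component boundaries come from sizes.
year = [2018, 2019, 2020, 2021, 2022]

ids = component_ids(sizes)
per_component = [year[start:end] for start, end in component_slices(sizes)]
```

Here the first two feature values belong to component 0 and the remaining three to component 1, which is the information that sizes carries alongside the features.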
| Args | |
|---|---|
| data | Nest of Field or subclasses of GraphPieceBase. |
| spec | A subclass of GraphPieceSpecBase with a _data_spec that matches data. |
Methods
from_fields
@classmethod
from_fields(
*_,
features: Optional[Fields] = None,
sizes: Field,
validate: Optional[bool] = None
) -> 'NodeSet'
Constructs a new instance from node set fields.
Example:
tfgnn.NodeSet.from_fields(
    sizes=tf.constant([3]),
    features={
        "tokenized_title": tf.ragged.constant(
            [["Anisotropic", "approximation"],
             ["Better", "bipartite", "bijection", "bounds"],
             ["Convolutional", "convergence", "criteria"]]),
        "embedding": tf.zeros([3, 128]),
        "year": tf.constant([2018, 2019, 2020]),
    })
| Args | |
|---|---|
| features | A mapping from feature name to feature Tensors or RaggedTensors. All feature tensors must have shape [*graph_shape, num_nodes, *feature_shape], where num_nodes is the number of nodes in the node set (could be ragged) and feature_shape is the shape of the feature value for each node. |
| sizes | The number of nodes in each graph component. Has shape [*graph_shape, num_components], where num_components is the number of graph components (could be ragged). |
| validate | If true, use tf.assert ops to inspect the shapes of each field and check at runtime that they form a valid NodeSet. The default behavior is controlled by disable_graph_tensor_validation_at_runtime() and enable_graph_tensor_validation_at_runtime(). |
| Returns | |
|---|---|
| A NodeSet composite tensor. | |
get_features_dict
get_features_dict() -> Dict[FieldName, Field]
Returns a copy of the features as a dictionary.
replace_features
replace_features(
features: Mapping[FieldName, Field]
) -> '_NodeOrEdgeSet'
Returns a new instance with a new set of features.
set_shape
set_shape(
new_shape: ShapeLike
) -> 'GraphPieceBase'
Deprecated. Use with_shape().
with_indices_dtype
with_indices_dtype(
dtype: tf.dtypes.DType
) -> 'GraphPieceBase'
Returns a copy of this piece with the given indices dtype.
with_row_splits_dtype
with_row_splits_dtype(
dtype: tf.dtypes.DType
) -> 'GraphPieceBase'
Returns a copy of this piece with the given row splits dtype.
with_shape
with_shape(
new_shape: ShapeLike
) -> 'GraphPieceBase'
Enforce the common prefix shape on all the contained features.
__getitem__
__getitem__(
feature_name: FieldName
) -> Field
Indexing operator [] to access feature values by their name.