Skip to content

Fix GraphUpdate from_config for nested serialized layers (improves Keras / tfmot compatibility)#923

Open
ur-miya wants to merge 1 commit into
tensorflow:mainfrom
ur-miya:fix-gnn-from-config-robustness
Open

Fix GraphUpdate from_config for nested serialized layers (improves Keras / tfmot compatibility)#923
ur-miya wants to merge 1 commit into
tensorflow:mainfrom
ur-miya:fix-gnn-from-config-robustness

Conversation

@ur-miya

@ur-miya ur-miya commented Jun 3, 2026

Copy link
Copy Markdown

Summary

This PR makes several TF‑GNN Keras layers (GraphUpdate, EdgeSetUpdate, NodeSetUpdate, ContextUpdate) more robust when reconstructed from serialized configs, especially in setups where tooling such as tfmot.quantization.keras serializes nested layers as config dicts instead of already-instantiated Layer objects.

The constructors and _check_is_layer(...) contract are unchanged; only from_config(...) paths are made tolerant to seeing serialized sub‑layer configs and normalize them back to tf.keras.layers.Layer instances before calling __init__.


Problem

In some environments (e.g. TF 2.16 + tf-keras + tfmot QAT), the following pattern can occur:

  • A model containing TF‑GNN layers (e.g. GraphUpdate) is cloned/serialized via Keras (clone_model, model.to_json, etc.).
  • Nested sub‑layers (e.g. EdgeSetUpdate, NodeSetUpdate.next_state, context/node/edge inputs) are stored in the config as dicts of the form {class_name, config, ...} rather than as Layer instances.
  • Later, Keras calls GraphUpdate.from_config(config) / EdgeSetUpdate.from_config(config). The corresponding __init__ implementations still expect actual Layer objects and call _check_is_layer(...), leading to errors like:

GraphUpdate(edge_sets={edge: ...}) must be a tf.keras.layer.Layer, got type: dict
EdgeSetUpdate(next_state=...) must be a tf.keras.layer.Layer, got type: dict.

This makes it difficult to combine TF‑GNN models with tooling that relies on cloning and graph transformations via Keras configs (e.g. QAT with tfmot).


Approach

The idea is to leave _check_is_layer(...) and the type guarantees in the constructors unchanged, and instead make from_config(...) smart enough to recognize serialized Keras layers and deserialize them back into Layer instances.

  1. Introduce a small helper:
def _maybe_deserialize_layer(obj):
  if isinstance(obj, tf.keras.layers.Layer):
    return obj
  if isinstance(obj, dict) and "class_name" in obj and "config" in obj:
    return tf.keras.layers.deserialize(obj)
  return obj
  1. Update the from_config(...) methods:
  • GraphUpdate

    • Keep the existing du.pop_by_prefix logic:

      config["edge_sets"] = du.pop_by_prefix(config, "edge_sets/")
      config["node_sets"] = du.pop_by_prefix(config, "node_sets/")
    • After that, run _maybe_deserialize_layer on:

      • each value in config["edge_sets"] and config["node_sets"];
      • config["context"] (if present).
  • EdgeSetUpdate

    • Add a custom from_config that:
      • copies the config;
      • applies _maybe_deserialize_layer to config["next_state"] (if present);
      • calls cls(**config).
  • NodeSetUpdate

    • Extend from_config to:
      • call du.pop_by_prefix(config, "edge_set_inputs/") as before;
      • run _maybe_deserialize_layer on each edge_set_inputs[...];
      • run _maybe_deserialize_layer on next_state (if present).
  • ContextUpdate

    • Extend from_config to:
      • call du.pop_by_prefix on node_set_inputs/* and edge_set_inputs/*;
      • run _maybe_deserialize_layer on each node_set_inputs[...] and edge_set_inputs[...];
      • run _maybe_deserialize_layer on next_state (if present).

In all cases, _check_is_layer(...) in the constructors is unchanged and will still reject non‑Layer values that could not be normalized.


Rationale / why this is safe

  • We do not change runtime call semantics or relax any type checks. All constructors still require actual tf.keras.layers.Layer instances and _check_is_layer(...) continues to enforce that.
  • Only from_config(...) code paths are touched, i.e. reconstruction from serialized configs (load_model, clone_model, etc.), not ordinary model building.
  • _maybe_deserialize_layer(...) is conservative:
    • if obj is already a Layer, it is returned as‑is;
    • if obj is a dict with class_name and config, it matches the standard Keras serialization format and is deserialized via tf.keras.layers.deserialize;
    • in all other cases obj is left unchanged and _check_is_layer(...) will still raise if it is not a Layer.

The pattern is applied consistently across GraphUpdate, EdgeSetUpdate, NodeSetUpdate and ContextUpdate, which makes future maintenance and reasoning about serialization behavior easier.


Tests

  • New tests in tensorflow_gnn/keras/layers/graph_update_test.py:

    • GraphUpdateSerializationTest.test_graph_update_from_config_deserializes_nested_layers

      • Constructs a GraphUpdate with nested EdgeSetUpdate / NodeSetUpdate using TF‑GNN layers (including Pool and NextStateFromConcat), calls get_config() and GraphUpdate.from_config(config), and asserts the rebuilt instance is a GraphUpdate whose get_config() still contains keys like "edge_sets/edge" and "node_sets/node".
    • GraphUpdateSerializationTest.test_edge_set_update_from_config_deserializes_next_state

      • Verifies that EdgeSetUpdate.from_config(...) correctly restores next_state and that get_config() on the rebuilt layer includes "next_state".
    • GraphUpdateSerializationTest.test_node_set_update_from_config_deserializes_nested_layers

      • Verifies that NodeSetUpdate.from_config(...) correctly restores edge_set_inputs and next_state, and that get_config() on the rebuilt layer includes "edge_set_inputs/edge" and "next_state".
  • Existing tests in graph_update_test.py continue to pass.

  • All tests were run under a fresh virtualenv with TensorFlow 2.16+, tensorflow-gnn 1.0.3 and tf-keras, with:

    TF_USE_LEGACY_KERAS=1 python -m tensorflow_gnn.keras.layers.graph_update_test

@google-cla

google-cla Bot commented Jun 3, 2026

Copy link
Copy Markdown

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@arnoegw

arnoegw commented Jun 12, 2026

Copy link
Copy Markdown
Collaborator

Hi @ur-miya , thank you for sending a contribution to TensorFlow GNN!

Your problem analysis is spot-on: As of TF/Keras version 2.13, (de)serialization to/from JSON configs has changed as part of the switch to the new save_format="keras" (transiently known as "keras_v3"), including the problematic behavior you describe. This was a regression behind TF/Keras 2.9+, which had purposefully fixed that problem with commit keras-team/keras@5caa668. We have privately reported this regression to the Keras team early on in 2023, to no avail. As a result, TF-GNN 1.0 has documented this way of serialization to be "unsupported and broken".

Before we go any further: please be aware that, at this point, there is no known timeline for a next release of TF-GNN. That means, merging your pull request will not directly help users that rely on pip install tensorflow_gnn. - If that's a deal breaker for you, let's save each other's time and stop early, because I see some open questions about your change as well:

How does (de)serialization work for the layers passed into the various next_state args? The library provides a few standard ones at tensorflow_gnn/keras/layers/next_state.py. While they don't explicitly call a helper like _check_is_layer(), they won't work without an inner Layer when called. Conspicuously, the revised unit test in this PR just checks initialization of the Layer classes; I believe it should also check calling them.

More esoterically, there is also the question of sub-objects that are shared between higher-level layers to express weight sharing: Suppose two next_state objects share the same instance of tf.keras.layer.Dense. That is possible and (at least potentially) useful to share their trained weights between different graph pieces and/or successive graph updates, which matters for training behavior and for model size. Will repeated calls to tf.keras.layers.deserialize(config) return the same or distinct but equal objects? I'm afraid not (but I could be wrong). It's not for TF-GNN to start replicating missing Keras features, but at least then TF-GNN should very loudly document that round-tripping through serialization would subtly but severely change the architecture of such a model (or not, sneakily, depending on the Keras version installed).

How would you like to proceed?

@arnoegw arnoegw self-assigned this Jun 12, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants