From b98904d9edafbe484a1b8af343ef45110cf64364 Mon Sep 17 00:00:00 2001 From: Hana Joo Date: Thu, 2 Jul 2026 03:24:18 -0700 Subject: [PATCH] Automated Code Change PiperOrigin-RevId: 941594773 --- dgf/src/io/cache.py | 2 +- dgf/src/io/graph_in_beam.py | 14 ++++++------- dgf/src/io/graph_in_memory.py | 8 ++++---- dgf/src/io/hgraph_in_avro.py | 14 ++++++------- dgf/src/io/hgraph_in_avro_test.py | 14 ++++++------- dgf/src/io/hgraph_in_beam.py | 4 ++-- dgf/src/io/hgraph_in_memory.py | 22 ++++++++++---------- dgf/src/io/jax.py | 2 +- dgf/src/io/parquet.py | 2 +- dgf/src/io/schema.py | 4 ++-- dgf/src/io/spanner.py | 32 +++++++++++++++--------------- dgf/src/io/tf_graph_sample.py | 8 ++++---- dgf/src/io/tf_graph_sample_test.py | 2 +- dgf/src/io/tfexample.py | 4 ++-- 14 files changed, 66 insertions(+), 66 deletions(-) diff --git a/dgf/src/io/cache.py b/dgf/src/io/cache.py index 2661b8a..2cb1b77 100644 --- a/dgf/src/io/cache.py +++ b/dgf/src/io/cache.py @@ -123,7 +123,7 @@ def cache( if len(found_vars) == len(variable_names): if return_tuple: - return tuple(found_vars) + return tuple(found_vars) # pyrefly: ignore[bad-return] else: return found_vars[0] diff --git a/dgf/src/io/graph_in_beam.py b/dgf/src/io/graph_in_beam.py index 57f9be4..7cccf7b 100644 --- a/dgf/src/io/graph_in_beam.py +++ b/dgf/src/io/graph_in_beam.py @@ -104,7 +104,7 @@ def read_graph( # Metadata with filesystem.open_read(os.path.join(path, FILENAME_METADATA)) as f: - metadata = gf_metadata_lib.GFGraphMetadata.from_json(f.read()) + metadata = gf_metadata_lib.GFGraphMetadata.from_json(f.read()) # pyrefly: ignore[missing-attribute] if metadata.version > MAX_SUPPORTED_GF_VERSION: raise NotImplementedError( @@ -289,7 +289,7 @@ def write_graph( metadata = gf_metadata_lib.GFGraphMetadata(version=MAX_SUPPORTED_GF_VERSION) metadata_path = os.path.join(path, FILENAME_METADATA) with filesystem.open_write(metadata_path) as f: - f.write(metadata.to_json(indent=2)) + f.write(metadata.to_json(indent=2)) # pyrefly: ignore[missing-attribute] write_results = [] @@ -346,7 +346,7 @@ def _feature_schema_to_parquet_fields( """Creates the schema for the parquet node container.""" fields = [] # Note: The schema has the node "#id". - for feature_name, feature_schema in feature_schema.items(): + for feature_name, feature_schema in feature_schema.items(): # pyrefly: ignore[missing-attribute] pa_type = FEATURE_FORMAT_TO_PY_ARROW_DTYPE[feature_schema.format] shape = feature_schema.shape if shape is None: @@ -364,7 +364,7 @@ def _node_schema_to_parquet_schema( node_schema: schema_lib.NodeSchema, ) -> pyarrow.Schema: """Creates the schema for the parquet node container.""" - return pyarrow.schema(_feature_schema_to_parquet_fields(node_schema.features)) + return pyarrow.schema(_feature_schema_to_parquet_fields(node_schema.features)) # pyrefly: ignore[bad-argument-type] def _edge_schema_to_parquet_schema( @@ -397,7 +397,7 @@ def _edge_schema_to_parquet_schema( pyarrow.field( KEY_TARGET, FEATURE_FORMAT_TO_PY_ARROW_DTYPE[target_node_format] ), - ] + _feature_schema_to_parquet_fields(edge_schema.features) + ] + _feature_schema_to_parquet_fields(edge_schema.features) # pyrefly: ignore[bad-argument-type] return pyarrow.schema(fields) @@ -411,7 +411,7 @@ def _node_to_raw( if feature_name == primary_key: raw_dict[feature_name] = node.id else: - feature_values = node.features[feature_name] + feature_values = node.features[feature_name] # pyrefly: ignore[unsupported-operation] raw_dict[feature_name] = feature_values.tolist() return raw_dict @@ -425,6 +425,6 @@ def _edge_to_raw( KEY_TARGET: edge.target, } for feature_name in schema.features: - feature_values = edge.features[feature_name] + feature_values = edge.features[feature_name] # pyrefly: ignore[unsupported-operation] raw_dict[feature_name] = feature_values.tolist() return raw_dict diff --git a/dgf/src/io/graph_in_memory.py b/dgf/src/io/graph_in_memory.py index 46bdee7..1e028c2 100644 --- a/dgf/src/io/graph_in_memory.py +++ b/dgf/src/io/graph_in_memory.py @@ -171,7 +171,7 @@ def _read_edge_set( target_mapper, source_ids, target_ids, - min(32, os.cpu_count()), + min(32, os.cpu_count()), # pyrefly: ignore[bad-specialization] ) else: # Slow path @@ -280,7 +280,7 @@ def read_graph( with filesystem.open_read(os.path.join(path, FILENAME_METADATA)) as f: if verbose: log.info("Reading metadata from %s", path) - metadata = gf_metadata_lib.GFGraphMetadata.from_json(f.read()) + metadata = gf_metadata_lib.GFGraphMetadata.from_json(f.read()) # pyrefly: ignore[missing-attribute] if metadata.version > MAX_SUPPORTED_GF_VERSION: raise NotImplementedError( @@ -376,7 +376,7 @@ def write_graph( if verbose: log.info("Writing metadata to %s", metadata_path) with filesystem.open_write(metadata_path) as f: - f.write(metadata.to_json(indent=2)) + f.write(metadata.to_json(indent=2)) # pyrefly: ignore[missing-attribute] # Write Node Sets node_dir = os.path.join(path, FILENAME_NODE_FEATURE) @@ -390,7 +390,7 @@ def write_graph( if verbose: log.info("Writing nodeset %s to %s", nodeset_name, node_dir) - num_shards, _ = shard_lib.estimate_num_node_shards(node_set.num_nodes) + num_shards, _ = shard_lib.estimate_num_node_shards(node_set.num_nodes) # pyrefly: ignore[bad-argument-type] if max_num_shards is not None: num_shards = min(num_shards, max_num_shards) diff --git a/dgf/src/io/hgraph_in_avro.py b/dgf/src/io/hgraph_in_avro.py index c509160..0e3715d 100644 --- a/dgf/src/io/hgraph_in_avro.py +++ b/dgf/src/io/hgraph_in_avro.py @@ -82,7 +82,7 @@ def _generate_node_records( """Generates records for writing node sets to Avro.""" iterator = range(start_index, end_index) if verbose: - iterator = tqdm( + iterator = tqdm( # pyrefly: ignore[not-callable] iterator, desc=f" - Writing nodes for '{name}'", unit="node", @@ -109,7 +109,7 @@ def _generate_edge_records( """Generates records for writing edge sets to Avro.""" iterator = range(start_index, end_index) if verbose: - iterator = tqdm( + iterator = tqdm( # pyrefly: ignore[not-callable] iterator, desc=f" - Writing edges for '{name}'", unit="edge", @@ -152,7 +152,7 @@ def write_avro_node_sets( parsed_schema = fastavro.parse_schema(avro_schema_dict) feature_items = list(nodeset.features.items()) num_shards, num_nodes_per_shard = shard_lib.estimate_num_node_shards( - nodeset.num_nodes + nodeset.num_nodes # pyrefly: ignore[bad-argument-type] ) for shard_index in range(num_shards): filename = shard_lib.sharded_filename( @@ -163,7 +163,7 @@ def write_avro_node_sets( ) filepath = os.path.join(directory, filename) start_index = shard_index * num_nodes_per_shard - end_index = min( + end_index = min( # pyrefly: ignore[bad-specialization] (shard_index + 1) * num_nodes_per_shard, nodeset.num_nodes ) with filesystem.open_write(filepath, binary=True) as f_out: @@ -173,7 +173,7 @@ def write_avro_node_sets( _generate_node_records( feature_items, start_index, - end_index, + end_index, # pyrefly: ignore[bad-argument-type] nodeset_name, verbose, ), @@ -280,7 +280,7 @@ def read_avro_record( reader = fastavro_reader(f_in) record_iterator = reader if verbose: - record_iterator = tqdm( + record_iterator = tqdm( # pyrefly: ignore[not-callable] reader, desc=f" - Reading records from {avro_file}", unit="record", @@ -288,7 +288,7 @@ def read_avro_record( for record in record_iterator: num_records += 1 for feature_name in feature_builders.keys(): - feature_builders[feature_name].append(record[feature_name]) + feature_builders[feature_name].append(record[feature_name]) # pyrefly: ignore[bad-index, unsupported-operation] # Convert lists to numpy arrays final_features = {} diff --git a/dgf/src/io/hgraph_in_avro_test.py b/dgf/src/io/hgraph_in_avro_test.py index 8739b42..0633d1d 100644 --- a/dgf/src/io/hgraph_in_avro_test.py +++ b/dgf/src/io/hgraph_in_avro_test.py @@ -148,7 +148,7 @@ def test_read_avro_record(self): "f1": ("int64", ()), "f2": ("float32", ()), } - data, num_records = avro_lib.read_avro_record([path], columns, False) + data, num_records = avro_lib.read_avro_record([path], columns, False) # pyrefly: ignore[bad-argument-type] self.assertEqual(num_records, 2) np.testing.assert_array_equal(data["f1"], np.array([1, 2], dtype="int64")) @@ -175,7 +175,7 @@ def test_read_avro_record_sharded(self): columns = {"f1": ("int64", ())} data, num_records = avro_lib.read_avro_record( - [path1, path2], columns, False + [path1, path2], columns, False # pyrefly: ignore[bad-argument-type] ) self.assertEqual(num_records, 3) @@ -194,7 +194,7 @@ def test_read_avro_record_empty(self): fastavro.writer(f, schema, records) columns = {"f1": ("int64", ())} - data, num_records = avro_lib.read_avro_record([path], columns, False) + data, num_records = avro_lib.read_avro_record([path], columns, False) # pyrefly: ignore[bad-argument-type] self.assertEqual(num_records, 0) self.assertEqual(data["f1"].shape, (0,)) @@ -221,7 +221,7 @@ def test_write_avro_node_sets(self): "f2": ("float32", (2,)), } n1_data, n1_num_records = avro_lib.read_avro_record( - [n1_path], n1_cols, False + [n1_path], n1_cols, False # pyrefly: ignore[bad-argument-type] ) self.assertEqual(n1_num_records, 2) np.testing.assert_array_equal(n1_data["#id"], np.array([b"1", b"2"])) @@ -237,7 +237,7 @@ def test_write_avro_node_sets(self): self.assertTrue(os.path.exists(n2_path)) n2_cols = {"#id": ("int64", ()), "f3": ("int64", ()), "f4": ("int64", ())} n2_data, n2_num_records = avro_lib.read_avro_record( - [n2_path], n2_cols, False + [n2_path], n2_cols, False # pyrefly: ignore[bad-argument-type] ) self.assertEqual(n2_num_records, 2) np.testing.assert_array_equal(n2_data["#id"], np.array([1, 2])) @@ -272,7 +272,7 @@ def test_write_avro_edge_sets(self): "#id": ("bytes", ()), } e1_data, e1_num_records = avro_lib.read_avro_record( - [e1_path], e1_cols, False + [e1_path], e1_cols, False # pyrefly: ignore[bad-argument-type] ) self.assertEqual(e1_num_records, 2) np.testing.assert_array_equal(e1_data["#source"], np.array([b"1", b"1"])) @@ -288,7 +288,7 @@ def test_write_avro_edge_sets(self): "#id": ("bytes", ()), } e2_data, e2_num_records = avro_lib.read_avro_record( - [e2_path], e2_cols, False + [e2_path], e2_cols, False # pyrefly: ignore[bad-argument-type] ) self.assertEqual(e2_num_records, 2) np.testing.assert_array_equal(e2_data["#source"], np.array([b"1", b"1"])) diff --git a/dgf/src/io/hgraph_in_beam.py b/dgf/src/io/hgraph_in_beam.py index b661533..81f8326 100644 --- a/dgf/src/io/hgraph_in_beam.py +++ b/dgf/src/io/hgraph_in_beam.py @@ -369,7 +369,7 @@ def tf_feature_to_feature( ) value = np.squeeze(value, axis=0) elif value.ndim != 1: - value = np.reshape(value, feature_schema.shape) + value = np.reshape(value, feature_schema.shape) # pyrefly: ignore[no-matching-overload] return value @@ -558,7 +558,7 @@ def node_to_tf_example( if feature_name == node_id_column: value = [node.id] else: - value = node.features[feature_name] + value = node.features[feature_name] # pyrefly: ignore[unsupported-operation] if value.ndim == 0: value = np.expand_dims(value, axis=0) diff --git a/dgf/src/io/hgraph_in_memory.py b/dgf/src/io/hgraph_in_memory.py index 811f954..53c21dc 100644 --- a/dgf/src/io/hgraph_in_memory.py +++ b/dgf/src/io/hgraph_in_memory.py @@ -334,11 +334,11 @@ def _read_container( """Reads features given a container type.""" if container_type == HGraphContainerType.TF_RECORD: features, num_records = tfexample_lib.read_tf_record( - paths=paths, columns=columns, verbose=verbose, preserve_order=False + paths=paths, columns=columns, verbose=verbose, preserve_order=False # pyrefly: ignore[bad-argument-type] ) elif container_type == HGraphContainerType.AVRO: features, num_records = hgraph_in_avro.read_avro_record( - paths=paths, columns=columns, verbose=verbose + paths=paths, columns=columns, verbose=verbose # pyrefly: ignore[bad-argument-type] ) else: raise ValueError( @@ -488,7 +488,7 @@ def mapper(ids: np.ndarray) -> Tuple[np.ndarray, int]: raw_edges, _ = _read_container( paths=paths, container_type=container_type, - columns=columns, + columns=columns, # pyrefly: ignore[bad-argument-type] verbose=verbose, key_column=edge_id_column, ) @@ -511,7 +511,7 @@ def mapper(ids: np.ndarray) -> Tuple[np.ndarray, int]: target_mapper, source_ids, target_ids, - min(32, os.cpu_count()), + min(32, os.cpu_count()), # pyrefly: ignore[bad-specialization] ) else: # Slow path @@ -629,16 +629,16 @@ def in_memory_node_to_tf_example( node_id_column is not None and node_id_column not in example.features.feature ): - if np.issubdtype(features[DEFAULT_KEY_ID].dtype, np.integer): + if np.issubdtype(features[DEFAULT_KEY_ID].dtype, np.integer): # pyrefly: ignore[unsupported-operation] example.features.feature[node_id_column].int64_list.value.append( - features[DEFAULT_KEY_ID][node_index] + features[DEFAULT_KEY_ID][node_index] # pyrefly: ignore[unsupported-operation] ) - elif features[DEFAULT_KEY_ID].dtype.kind == "S": + elif features[DEFAULT_KEY_ID].dtype.kind == "S": # pyrefly: ignore[unsupported-operation] example.features.feature[node_id_column].bytes_list.value.append( - features[DEFAULT_KEY_ID][node_index] + features[DEFAULT_KEY_ID][node_index] # pyrefly: ignore[unsupported-operation] ) else: - raise ValueError(f"Non supported type {features[DEFAULT_KEY_ID]}") + raise ValueError(f"Non supported type {features[DEFAULT_KEY_ID]}") # pyrefly: ignore[unsupported-operation] return example @@ -773,7 +773,7 @@ def _write_tfrecord_node_sets( """Writes node sets to TFRecord files.""" for nodeset_name, nodeset in graph.node_sets.items(): num_shards, num_nodes_per_shard = shard_lib.estimate_num_node_shards( - nodeset.num_nodes + nodeset.num_nodes # pyrefly: ignore[bad-argument-type] ) for shard_index in range(num_shards): examples = [] @@ -785,7 +785,7 @@ def _write_tfrecord_node_sets( ) for node_index in range( shard_index * num_nodes_per_shard, - min( + min( # pyrefly: ignore[bad-argument-type, bad-specialization] (shard_index + 1) * num_nodes_per_shard, nodeset.num_nodes, ), diff --git a/dgf/src/io/jax.py b/dgf/src/io/jax.py index ba22019..6a8cd0b 100644 --- a/dgf/src/io/jax.py +++ b/dgf/src/io/jax.py @@ -47,7 +47,7 @@ def _asarray(x): for node_set_name, node_set in src.node_sets.items(): jax_features = {k: _asarray(v) for k, v in node_set.features.items()} jax_node_sets[node_set_name] = jax_in_memory_graph_lib.JaxInMemoryNodeSet( - features=jax_features, num_nodes=node_set.num_nodes + features=jax_features, num_nodes=node_set.num_nodes # pyrefly: ignore[bad-argument-type] ) jax_edge_sets = {} diff --git a/dgf/src/io/parquet.py b/dgf/src/io/parquet.py index 27eab45..3ea9f54 100644 --- a/dgf/src/io/parquet.py +++ b/dgf/src/io/parquet.py @@ -34,7 +34,7 @@ @numba.njit(parallel=False) def _numba_copy_kernel(offset, offsets: np.ndarray, data, out, max_len): """Copies data from a flat array to a padded output array using Numba.""" - for i in numba.prange(len(offsets) - 1): + for i in numba.prange(len(offsets) - 1): # pyrefly: ignore[not-iterable] start = offsets[i] end = offsets[i + 1] length = end - start diff --git a/dgf/src/io/schema.py b/dgf/src/io/schema.py index 9dbf782..794bb3c 100644 --- a/dgf/src/io/schema.py +++ b/dgf/src/io/schema.py @@ -34,7 +34,7 @@ def read_schema(path: str) -> schema_lib.GraphSchema: The loaded graph schema. """ with filesystem.open_read(path) as f: - return schema_lib.GraphSchema.from_json(f.read()) + return schema_lib.GraphSchema.from_json(f.read()) # pyrefly: ignore[missing-attribute] def write_schema(schema: schema_lib.GraphSchema, path: str): @@ -52,4 +52,4 @@ def write_schema(schema: schema_lib.GraphSchema, path: str): path: Output path. """ with filesystem.open_write(path) as f: - f.write(schema.to_json(indent=2)) + f.write(schema.to_json(indent=2)) # pyrefly: ignore[missing-attribute] diff --git a/dgf/src/io/spanner.py b/dgf/src/io/spanner.py index fe5bc51..db2387b 100644 --- a/dgf/src/io/spanner.py +++ b/dgf/src/io/spanner.py @@ -193,7 +193,7 @@ def schema_to_spanner_ddl( Raises: ValueError: If the max_str_length is not a valid value. """ - max_bytes_length = max_bytes_length if max_bytes_length else "MAX" + max_bytes_length = max_bytes_length if max_bytes_length else "MAX" # pyrefly: ignore[bad-assignment] ddl_statements: Dict[str, str] = {} @@ -217,9 +217,9 @@ def schema_to_spanner_ddl( sql_type = feature_format_to_spanner_type( feature_schema.format, - feature_schema.shape, + feature_schema.shape, # pyrefly: ignore[bad-argument-type] max_bytes_length=max_bytes_length, - is_utf8_string=feature_schema.is_utf8_string, + is_utf8_string=feature_schema.is_utf8_string, # pyrefly: ignore[bad-argument-type] ) if is_id: feature_ddls.append(f"{feature_name} {sql_type} NOT NULL") @@ -254,9 +254,9 @@ def schema_to_spanner_ddl( source_type = feature_format_to_spanner_type( source_feature_id.format, - source_feature_id.shape, + source_feature_id.shape, # pyrefly: ignore[bad-argument-type] max_bytes_length=max_bytes_length, - is_utf8_string=source_feature_id.is_utf8_string, + is_utf8_string=source_feature_id.is_utf8_string, # pyrefly: ignore[bad-argument-type] ) tgt_node_schema = schema.node_sets[edge_set_schema.target] @@ -268,9 +268,9 @@ def schema_to_spanner_ddl( target_type = feature_format_to_spanner_type( target_feature_id.format, - target_feature_id.shape, + target_feature_id.shape, # pyrefly: ignore[bad-argument-type] max_bytes_length=max_bytes_length, - is_utf8_string=target_feature_id.is_utf8_string, + is_utf8_string=target_feature_id.is_utf8_string, # pyrefly: ignore[bad-argument-type] ) current_ddl_statement = ( @@ -297,8 +297,8 @@ def schema_to_spanner_ddl( sql_type = feature_format_to_spanner_type( feature_schema.format, - feature_schema.shape, - is_utf8_string=feature_schema.is_utf8_string, + feature_schema.shape, # pyrefly: ignore[bad-argument-type] + is_utf8_string=feature_schema.is_utf8_string, # pyrefly: ignore[bad-argument-type] ) if is_id: feature_ddls.append(f" {feature_name} {sql_type} NOT NULL") @@ -402,7 +402,7 @@ def node_to_spanner_row( Returns: An instance a `cls` object with the node data. """ - node_dict = features_to_dict(node.features) + node_dict = features_to_dict(node.features) # pyrefly: ignore[bad-argument-type] node_dict[id_key] = node.id return cls(**node_dict) @@ -617,10 +617,10 @@ def create_spanner_row_type_from_node_schema( column_names_and_types.append(( feature_name, feature_format_to_type_hint( - feature_schema.format, feature_schema.shape + feature_schema.format, feature_schema.shape # pyrefly: ignore[bad-argument-type] ), )) - return NamedTuple(f"{node_set_name}_{name_suffix}", column_names_and_types) + return NamedTuple(f"{node_set_name}_{name_suffix}", column_names_and_types) # pyrefly: ignore[bad-argument-count, bad-return] def create_spanner_row_type_from_edge_schema( @@ -645,13 +645,13 @@ def create_spanner_row_type_from_edge_schema( _DEFAULT_ID_KEY_GNN ] source_hint = feature_format_to_type_hint( - source_feature_id.format, source_feature_id.shape + source_feature_id.format, source_feature_id.shape # pyrefly: ignore[bad-argument-type] ) target_feature_id = schema.node_sets[edge_schema.target].features[ _DEFAULT_ID_KEY_GNN ] target_hint = feature_format_to_type_hint( - target_feature_id.format, target_feature_id.shape + target_feature_id.format, target_feature_id.shape # pyrefly: ignore[bad-argument-type] ) column_names_and_types = [("source", source_hint), ("target", target_hint)] @@ -659,10 +659,10 @@ def create_spanner_row_type_from_edge_schema( column_names_and_types.append(( feature_name, feature_format_to_type_hint( - feature_schema.format, feature_schema.shape + feature_schema.format, feature_schema.shape # pyrefly: ignore[bad-argument-type] ), )) - return NamedTuple(f"{edge_set_name}_{name_suffix}", column_names_and_types) + return NamedTuple(f"{edge_set_name}_{name_suffix}", column_names_and_types) # pyrefly: ignore[bad-argument-count, bad-return] def create_spanner_row_types_from_schema( diff --git a/dgf/src/io/tf_graph_sample.py b/dgf/src/io/tf_graph_sample.py index ed77e6e..4a79261 100644 --- a/dgf/src/io/tf_graph_sample.py +++ b/dgf/src/io/tf_graph_sample.py @@ -77,7 +77,7 @@ def _parse_feature( feature_value = example[src_feature_name] if feature_schema.is_static_shape(): - feature_value = feature_value.reshape( + feature_value = feature_value.reshape( # pyrefly: ignore[no-matching-overload] (num_items,) + (feature_schema.shape or ()) ) dst_dict[dst_feature_name] = feature_value @@ -626,7 +626,7 @@ def generator(): if isinstance(container_type, str): container_type = TFGraphSampleContainerType[container_type] paths = shard_lib.expand_input_paths(path) - path_dataset = tf.data.Dataset.from_tensor_slices(paths) + path_dataset = tf.data.Dataset.from_tensor_slices(paths) # pyrefly: ignore[bad-argument-type] # Build the tf parsing spec. feature_spec = build_tfgnn_feature_spec(schema) @@ -634,7 +634,7 @@ def generator(): if container_type == TFGraphSampleContainerType.TF_RECORD: def read_serialized_proto_dataset(path): - return tf.data.TFRecordDataset(path, compression_type=compression) + return tf.data.TFRecordDataset(path, compression_type=compression) # pyrefly: ignore[bad-instantiation] else: raise ValueError("Non supported container type") @@ -712,7 +712,7 @@ def graphs_to_serialized_tfgnn_graphs( graphs: Sequence[in_memory_graph.InMemoryGraph], schema: schema_lib.GraphSchema | None = None, *, - num_threads: int = os.cpu_count() * 2, + num_threads: int = os.cpu_count() * 2, # pyrefly: ignore[unsupported-operation] ) -> List[bytes]: """Converts a sequence of InMemoryGraphs into serialized TF-GNN graph sample protos. diff --git a/dgf/src/io/tf_graph_sample_test.py b/dgf/src/io/tf_graph_sample_test.py index 6a4d7bb..4ce3b6f 100644 --- a/dgf/src/io/tf_graph_sample_test.py +++ b/dgf/src/io/tf_graph_sample_test.py @@ -200,7 +200,7 @@ def in_mem_graphs(): graph, schema=schema ) for current_path in paths: - for tensor in tf.data.TFRecordDataset( + for tensor in tf.data.TFRecordDataset( # pyrefly: ignore[bad-instantiation] current_path, compression_type="GZIP" ): read_example = tf.train.Example.FromString(tensor.numpy()) diff --git a/dgf/src/io/tfexample.py b/dgf/src/io/tfexample.py index 9d1dad8..ec8f2be 100644 --- a/dgf/src/io/tfexample.py +++ b/dgf/src/io/tfexample.py @@ -42,7 +42,7 @@ def _read_dataset_generic( col: tf.io.VarLenFeature(dtype) for col, (dtype, _) in columns.items() } - path_dataset = tf.data.Dataset.from_tensor_slices(paths) + path_dataset = tf.data.Dataset.from_tensor_slices(paths) # pyrefly: ignore[bad-argument-type] if not paths: raise ValueError("paths should not be empty") @@ -215,7 +215,7 @@ def read_tf_record( """ def _read_tfrecord_dataset(path): - return tf.data.TFRecordDataset( + return tf.data.TFRecordDataset( # pyrefly: ignore[bad-instantiation] path, compression_type="GZIP" if compressed else "" )