
A problem in the layers.index_matrix_to_pairs function #10

@Chen-Wang-CUHK

Hi,
Thanks for sharing your code.
However, I found a problem in the `layers.index_matrix_to_pairs` function. With the original code, the input `[[3,1,2], [2,3,1]]` does not produce the output given in the function's comment:

    [[[0, 3], [1, 1], [2, 2]],
     [[0, 2], [1, 3], [2, 1]]]
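
For reference, here is a plain-Python sketch of the mapping in that comment (no TensorFlow; `expected_pairs` is a hypothetical helper, not part of the repository). Each value is paired with its column position within its row:

```python
def expected_pairs(index_matrix):
    """Pair each value with its position inside its row: [i][j] -> [j, value]."""
    # Hypothetical helper for checking outputs; not part of the repository.
    return [[[j, value] for j, value in enumerate(row)] for row in index_matrix]


print(expected_pairs([[3, 1, 2], [2, 3, 1]]))
# [[[0, 3], [1, 1], [2, 2]], [[0, 2], [1, 3], [2, 1]]]
```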

It works after I change the code to:

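```python
import tensorflow as tf


def index_matrix_to_pairs(index_matrix):
    # [[3,1,2], [2,3,1]] -> [[[0, 3], [1, 1], [2, 2]],
    #                        [[0, 2], [1, 3], [2, 1]]]
    index_matrix = tf.convert_to_tensor(index_matrix)
    # Position of each element along the last axis, cast to the input's dtype
    # so that tf.stack can combine the two tensors.
    replicated_first_indices = tf.cast(
        tf.range(tf.shape(index_matrix)[-1]), index_matrix.dtype)
    rank = len(index_matrix.get_shape())
    if rank == 2:
        # Repeat the row of positions once for every row of the input.
        replicated_first_indices = tf.tile(
            tf.expand_dims(replicated_first_indices, axis=0),
            [tf.shape(index_matrix)[0], 1]
        )
    # Stack along a new innermost axis: each entry becomes [position, value].
    return tf.stack([replicated_first_indices, index_matrix], axis=rank)
```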

I don't know whether the example in the comment or the code itself is wrong.
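
Which side is wrong depends on how the returned pairs are consumed, which this issue does not show. As an assumption to check against the caller: if the pairs are passed to `tf.gather_nd` on a batch-major tensor of shape `[batch, length]`, each value needs its row (batch) index rather than its column position. The plain-Python sketch below contrasts the two layouts; `row_index_pairs`, `gather_nd_2d`, `memory` and `indices` are hypothetical names, and `gather_nd_2d` only emulates `tf.gather_nd` for 2-element index pairs.

```python
def row_index_pairs(index_matrix):
    # [i][j] -> [i, value]: pair each value with its row (batch) index.
    # Hypothetical helper; the issue does not confirm this convention.
    return [[[i, v] for v in row] for i, row in enumerate(index_matrix)]


def gather_nd_2d(params, pairs):
    # Plain-Python stand-in for tf.gather_nd with 2-element index pairs:
    # each [a, b] selects params[a][b].
    return [[params[a][b] for a, b in row] for row in pairs]


memory = [[10, 11, 12, 13],   # batch 0, positions 0..3
          [20, 21, 22, 23]]   # batch 1, positions 0..3
indices = [[3, 1, 2], [2, 3, 1]]

print(row_index_pairs(indices))
# [[[0, 3], [0, 1], [0, 2]], [[1, 2], [1, 3], [1, 1]]]
print(gather_nd_2d(memory, row_index_pairs(indices)))
# [[13, 11, 12], [22, 23, 21]]
```

With the column-position layout from the comment, the first coordinate runs over positions 0 to 2, so the same lookup would try `memory[2]`, a batch that does not exist here; that layout only fits a tensor whose first axis is the position axis.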
