当前位置:  开发笔记 > 编程语言 > 正文

如何在1D Tensor中查找重复元素

如何解决《如何在1DTensor中查找重复元素》经验,为你挑选了1个好方法。

我想获得在1D张量中出现多次的元素.确切地说,我想创建一个与之相反的功能tf.unique.例如,如果x = [1, 1, 2, 3, 4, 5, 6, 7, 4, 5, 4]我需要输出,[1,1,4,4,4,5,5]同时还检索源张量中的那些元素的索引.我的最终目标是在批处理中获取标签出现多次的示例.



1> dga..:

您可以使用现有的Tensorflow操作以略微圆整的方式执行此操作,方法是计算唯一项目以创建唯一项目的密集索引集合,然后使用它们进行计数tf.unsorted_segment_sum.一旦你的计数,选择的项目> N使用tf.greater,并收集起来,放回密列表:

import tensorflow as tf

a = tf.constant([8, 7, 8, 1, 3, 4, 5, 9, 5, 0, 5])
init = tf.initialize_all_variables()

unique_a_vals, unique_idx = tf.unique(a)
count_a_unique = tf.unsorted_segment_sum(tf.ones_like(a),                   
                                         unique_idx,                        
                                         tf.shape(a)[0])                    

more_than_one = tf.greater(count_a_unique, 1)                               
more_than_one_idx = tf.squeeze(tf.where(more_than_one))                     
more_than_one_vals = tf.squeeze(tf.gather(unique_a_vals, more_than_one_idx))

# If you want the original indexes:                                         
not_duplicated, _ = tf.listdiff(a, more_than_one_vals)                      
dups_in_a, indexes_in_a = tf.listdiff(a, not_duplicated)                    

with tf.Session() as s:                                                     
    s.run(init)                                                             
    a, dupvals, dupidxes, dia = s.run([a, more_than_one_vals,                    
                                  indexes_in_a, dups_in_a])                            
    print "Input: ", a                                                      
    print "Duplicate values: ", dupvals                                     
    print "Indexes of duplicates in a: ", dupidxes
    print "Dup vals with dups: ", dia

输入:[8 7 8 1 3 4 5 9 5 0 5]

重复值:[8 5]

重复索引:a [0 2 6 8 10]

带有重复的复式:[8 8 5 5 5]

推荐阅读
携手相约幸福
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有