pythonargmaxaxis1_np.argmax(input,axis)和tf.argmax(input,axis)使用

 2023-09-18 阅读 27 评论 0

摘要:np.argmax(input,axis)和tf.argmax(input,axis)分别是numpy和TensorFlow底下的求最大值索引的方法,用法基本一致,只有默认情况下有细微差别,以及传入的值略有不同,分别是array和tensor。说白了,是不同模块下的相同方法。。只是不同模块下&

np.argmax(input,axis)和tf.argmax(input,axis)分别是numpy和TensorFlow底下的求最大值索引的方法,用法基本一致,只有默认情况下有细微差别,以及传入的值略有不同,分别是array和tensor。

说白了,是不同模块下的相同方法。。只是不同模块下,数据类型不一致而已。。

一、np.argmax(input,axis)的使用

tf.argmax(input,axis),根据axis取值的不同返回每行或者每列(在axis上比较)最大值的索引。

1.数组长度一致时,2维数组:

test = np.array([ [1, 2, 3], [2, 3, 4], [5, 4, 3], [8, 7, 2]])

print(np.argmax(test))

np.argmax(test, 0)

np.argmax(test, 1)

#输出:

9

[3, 3, 1]

[2, 2, 0, 0]

axis = 0:列最大索引

axis=0时,比较每一列元素,记录每一列最大元素所在的索引,最后输出每一列最大元素所在的索引数组。

test[0] = array([1, 2, 3])

test[1] = array([2, 3, 4])

test[2] = array([5, 4, 3])

test[3] = array([8, 7, 2])

# output   :       [3, 3, 1]

axis = 1:行最大索引

axis=1时,比较每一行元素,记录每一行最大元素所在的索引,最后返回每一行最大元素所在的索引数组。

test[0] = array([1, 2, 3])  #2

test[1] = array([2, 3, 4])  #2

test[2] = array([5, 4, 3])  #0

test[3] = array([8, 7, 2])  #0

2.数组长度一致时,n维数组:数组的shape很重要!

test = np.array([

[[19,

2, 3],

[2, 21, 2]],

[[5, 4, 3],

[1, 2, 3]],

[[5, 4, 6],

[1, 2, 3]],

[[15, 14, 13],

[11, 12, 3]]

])

# 本例中,

# test形状是4*2*3,这个特别重要,axis=0,就是4个同一位置的元素比较,axis=1就是2个元素比较,axis=2就是3个元素比较。

# 再举个例子,

# test形状是3*7*5*10,这个特别重要,axis=0,就是3个同一位置的元素比较,axis=1就是7个元素比较,axis=2就是5个元素比较,axis=3就是10个元素比较。

axis=None或省略

print(np.argmax(test, axis=None))

#输出:4

# axis=None和省略结果相同,直接当成一维数组来查,21最大,是第5个元素,从0开始,对应的下标是4。

axis=0:

print(np.argmax(test, 0))

#输出:

[[0 3 3]

[3 0 1]]

# axis = 0,其实是在第0维,也就是shape的第一个数4对应的那一维,比较4个元素的值。输出的是shape除了4之外的2*3的数组。

# 本例中,第一个元素0是4个元素19,5,5,15比较时max值19对应的索引,第二个元素3是4个元素2,4,4,14比较时max值14对应的索引……

axis = 1:

print(np.argmax(test, 1))

#输出:

[[0 1 0]

[0 0 0]

[0 0 0]

[0 0 0]]

# axis = 1,其实是在第1维,也就是shape的第二个数2对应的那一维,比较2个元素的值。输出的是shape除了2之外的4*3的数组。

# 本例中,第一个元素0是2个元素19,2比较时max值19对应的索引,第二个元素1是2个元素2,21比较时max值21对应的索引,第三个元素0是2个元素3,2比较时max值3对应的索引……

axis = 2:

print(np.argmax(test, 2))

#输出:

[[0 1]

[0 2]

[2 2]

[0 1]]

# axis = 2,其实是在第2维,也就是shape的第三个数3对应的那一维,比较3个元素的值。输出的是shape除了3之外的4*2的数组。

# 本例中,第一个元素0是3个元素19,2,3比较时max值19对应的索引,第二个元素1是3个元素2,21,2比较时max值21对应的索引,第三个元素0是3个元素5, 4, 3比较时max值5对应的索引,第四个元素2是3个元素1, 2, 3比较时max值3对应的索引……

3.数组长度不一致时:

axis最大值为数组维数-1,超过则报错。参考n维数组的例子,就是在每一个axis上比较的,很明显超过维度没有意义。

不一致时,axis=0的比较也就变成了每个数组的和的比较。【这个不理解,有问题?】

二、tf.argmax(input,axis)的使用

test = tf.Variable([

[1, 2, 3],

[2, 13, 4],

[5, 4, 3],

[1, 2, 7]])

print(tf.argmax(test))  # 这个与np.argmax不同,默认axis=None或省略,与axis=0的结果相同。

print(tf.argmax(test, 0))  # 与np.argmax相同,tensor形式

print(tf.argmax(test, 1)) # 与np.argmax相同,tensor形式

# 输出:

tf.Tensor([2 1 3], shape=(3,), dtype=int64)

tf.Tensor([2 1 3], shape=(3,), dtype=int64)

tf.Tensor([2 1 0 2], shape=(4,), dtype=int64)

参考:

版权声明:本站所有资料均为网友推荐收集整理而来,仅供学习和研究交流使用。

原文链接:https://hbdhgg.com/1/77025.html

发表评论:

本站为非赢利网站,部分文章来源或改编自互联网及其他公众平台,主要目的在于分享信息,版权归原作者所有,内容仅供读者参考,如有侵权请联系我们删除!

Copyright © 2022 匯編語言學習筆記 Inc. 保留所有权利。

底部版权信息