TensorFlow map_fn

    xiaoxiao  2022-07-05  175

    TensorFlow map_fn

    flyfish

    # Examples of tf.map_fn for TensorFlow 1.x (graph mode: ops are built
    # first, then evaluated inside a Session).
    import numpy as np
    import tensorflow as tf

    # Integer dtypes are pinned to int64 on purpose. NumPy's default integer
    # type is platform-dependent (int32 on Windows before NumPy 2.0), which
    # would not match the dtype=tf.int64 declared in the map_fn calls below.

    # 1) One input array: square every element.
    elems = np.array([1, 2, 3, 4, 5, 6], dtype=np.int64)
    squares = tf.map_fn(lambda x: x * x, elems)

    # 2) A tuple of input arrays: on each step fn receives a tuple holding
    # one element from each array. The output structure (a single tensor)
    # differs from the input structure (a tuple), so `dtype` must be given.
    elems = (np.array([1, 2, 3], dtype=np.int64),
             np.array([-1, 1, -1], dtype=np.int64))
    alternate = tf.map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
    # alternate == [-1, 2, -3]

    # 3) One input array, tuple output: `dtype` is a tuple that matches fn's
    # return structure.
    elems = np.array([1, 2, 3], dtype=np.int64)
    alternates = tf.map_fn(lambda x: (x, -x), elems,
                           dtype=(tf.int64, tf.int64))
    # alternates[0] == [1, 2, 3]
    # alternates[1] == [-1, -2, -3]

    # The `with` block makes this the default session and closes it on exit.
    # The original InteractiveSession was never closed.
    with tf.Session() as sess:
        print(sess.run(squares))        # [ 1  4  9 16 25 36]
        print(sess.run(alternate))      # [-1  2 -3]
        print(sess.run(alternates[0]))  # [1 2 3]
        print(sess.run(alternates[1]))  # [-1 -2 -3]
    最新回复(0)