
Training Different Branches Of Model Network With tf.switch_case

I want to create a neural network in which different branches of the network are trained depending on t_input. t_input can be either 0 or 1, and depending on that only th

Solution 1:

I got your code working by replacing your tf.switch_case with a Keras switch (K.backend.switch) and by passing both branches, x1 and x2, into it (your code only passed one of them). Note that I had to tile your t_test input, because the model expects both inputs to have the same batch dimension. I am also not sure that you want np.random.binomial: np.random.binomial(100, 0.5) counts successes over 100 trials, so it returns values around 50 and will almost never return 0 or 1. You should probably use np.random.randint(0, 2) instead, which only returns 0 or 1 (np.random.binomial(1, 0.5) would also work).

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense
import tensorflow.keras as K
import numpy as np

x_test = np.random.uniform(size=(10, 10))
t_test = np.array([np.random.randint(0, 2)])  # a single branch flag: 0 or 1

t_input = Input(shape=(1,), dtype=tf.int32, name="t_input")
x_input = Input(shape=(x_test.shape[1],), name='x_input')

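# shared layer, used by both branches
x = Dense(32)(x_input)

# branch 1
x1 = Dense(16)(x)
x1 = Dense(8)(x1)
x1 = Dense(1)(x1)

# branch 2
x2 = Dense(16)(x)
x2 = Dense(8)(x2)
x2 = Dense(1)(x2)

# per sample: take x1 where t_input is nonzero, x2 where t_input is 0
r = K.backend.switch(t_input, x1, x2)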

model = tf.keras.models.Model(inputs=[t_input,x_input], outputs=r)

# repeat the flag to shape (10, 1) so it matches the batch size of x_test
print(model.predict([np.tile(t_test, (10, 1)), x_test]))
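As far as I can tell from the TF 2.x implementation, K.backend.switch with a condition of the same rank as the branches casts the condition to bool and calls tf.where, so the choice is made per sample and each sample's loss only produces gradients in the branch it was routed to. A small sketch of that mechanism using tf.where directly (the tensors below are made-up stand-ins for x1, x2 and t_input, not part of the model above):

```python
import tensorflow as tf

# Made-up stand-ins for the two branch outputs, shape (batch, 1).
x1 = tf.Variable([[1.0], [2.0], [3.0], [4.0]])
x2 = tf.Variable([[10.0], [20.0], [30.0], [40.0]])
# Per-sample branch flags, shape (batch, 1): nonzero selects x1, 0 selects x2.
t = tf.constant([[1], [0], [1], [0]], dtype=tf.int32)

with tf.GradientTape() as tape:
    # Element-wise selection, similar to K.backend.switch for same-rank inputs.
    r = tf.where(tf.cast(t, tf.bool), x1, x2)
    loss = tf.reduce_sum(r)

g1, g2 = tape.gradient(loss, [x1, x2])
print(r.numpy().ravel())   # values 1, 20, 3, 40
print(g1.numpy().ravel())  # values 1, 0, 1, 0: only samples routed to x1
print(g2.numpy().ravel())  # values 0, 1, 0, 1: only samples routed to x2
```

In the full model the shared Dense(32) layer sits under both branches, so it still receives gradients from every sample; the branch-specific layers only receive nonzero gradients from the samples routed to them.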

Solution 2:

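I found a way to support more than two branches by concatenating the branch outputs and using tf.gather_nd to pick, for each sample, the column indexed by t_input (t_input must then hold values from 0 to the number of branches minus 1):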

import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda, Dense
import tensorflow.keras as K
import numpy as np

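def h_chooser(inputs):
  # inputs is [t_inp, [h_1, ..., h_n]]; t_inp has shape (batch, 1)
  t_inp, h_s = inputs
  h_out = tf.concat(h_s, axis=1)        # (batch, n_branches)
  t_flat = tf.reshape(t_inp, [-1])      # (batch,), also safe when batch == 1
  # one [row, branch] index pair per sample
  t_inds = tf.stack([tf.range(tf.size(t_flat)), t_flat], axis=1)
  h_res = tf.gather_nd(h_out, t_inds)   # (batch,)
  return h_res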

x_test = np.random.uniform(size=(10, 10))
t_test = np.array([np.random.randint(0, 3)])  # a single branch index: 0, 1 or 2

t_input = Input(shape=(1,), dtype=tf.int32, name="t_input")
x_input = Input(shape=(x_test.shape[1],), name='x_input')

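# shared layer, used by all three branches
x = Dense(32)(x_input)

# branch selected when t_input == 0
x1 = Dense(16)(x)
x1 = Dense(8)(x1)
x1 = Dense(1)(x1)

# branch selected when t_input == 1
x2 = Dense(16)(x)
x2 = Dense(8)(x2)
x2 = Dense(1)(x2)

# branch selected when t_input == 2
x3 = Dense(16)(x)
x3 = Dense(8)(x3)
x3 = Dense(1)(x3)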

h_switch_case = Lambda(h_chooser)

r = h_switch_case([t_input, [x1, x2, x3])

model = tf.keras.models.Model(inputs=[t_input,x_input], outputs=r)

# repeat the index to shape (10, 1) so it matches the batch size of x_test
print(model.predict([np.tile(t_test, (10, 1)), x_test]))
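The per-sample lookup done by h_chooser can be checked on its own with constant tensors. A sketch assuming TF 2.x: choose_branch is a hypothetical standalone version of h_chooser that takes the flags and the branch list as separate arguments, and the tf.gather call with batch_dims=1 is shown as an alternative way to do the same per-row lookup without building index pairs:

```python
import tensorflow as tf

def choose_branch(t, branch_outputs):
    """Hypothetical helper: for every sample, return the output of branch t.

    t: int32 tensor of shape (batch, 1), values in [0, len(branch_outputs)).
    branch_outputs: list of tensors, each of shape (batch, 1).
    """
    h = tf.concat(branch_outputs, axis=1)           # (batch, n_branches)
    t = tf.reshape(t, [-1])                         # (batch,), safe even if batch == 1
    rows = tf.range(tf.shape(t)[0], dtype=t.dtype)  # 0, 1, ..., batch - 1
    idx = tf.stack([rows, t], axis=1)               # (batch, 2) pairs [row, branch]
    return tf.gather_nd(h, idx)                     # (batch,)

# Three made-up branch outputs for a batch of 3 samples.
h1 = tf.constant([[1.0], [2.0], [3.0]])
h2 = tf.constant([[10.0], [20.0], [30.0]])
h3 = tf.constant([[100.0], [200.0], [300.0]])
t = tf.constant([[2], [0], [1]], dtype=tf.int32)

picked = choose_branch(t, [h1, h2, h3])
# Same selection with tf.gather and batch_dims=1 instead of explicit index pairs.
picked_alt = tf.gather(tf.concat([h1, h2, h3], axis=1), tf.reshape(t, [-1]), batch_dims=1)

print(picked.numpy())      # values 100, 2, 30
print(picked_alt.numpy())  # same values
```

Using tf.reshape(t, [-1]) rather than a bare tf.squeeze keeps a batch of one sample from collapsing to a scalar, which would otherwise make tf.stack fail.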
