# How to add an if condition to a TensorFlow graph?

## Python Problem Overview

Let's say I have the following code:

```python
x = tf.placeholder("float32", shape=[None, ins_size**2*3], name="x_input")
condition = tf.placeholder("int32", shape=[1, 1], name="condition")
W = tf.Variable(tf.zeros([ins_size**2*3, label_option]), name="weights")
b = tf.Variable(tf.zeros([label_option]), name="bias")
if condition > 0:
    y = tf.nn.softmax(tf.matmul(x, W) + b)
else:
    y = tf.nn.softmax(tf.matmul(x, W) - b)
```

Would the `if` statement work in the calculation (I do not think so)? If not, how can I add an `if` statement to the TensorFlow calculation graph?

## Python Solutions

## Solution 1 - Python

You're correct that the `if` statement doesn't work here, because the condition is evaluated at graph construction time, whereas presumably you want the condition to depend on the value fed to the placeholder at runtime. (In early TensorFlow releases it would always take the first branch, because `condition > 0` evaluates to a `Tensor`, which was "truthy" in Python. Later releases raise a `TypeError` instead of silently picking a branch.)

To support conditional control flow, TensorFlow provides the `tf.cond()` operator, which evaluates one of two branches depending on a boolean condition. To show you how to use it, I'll rewrite your program so that `condition` is a scalar `tf.int32` value for simplicity:

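```python
import tensorflow as tf  # TF 1.x API; in TF 2.x these calls live under tf.compat.v1

x = tf.placeholder(tf.float32, shape=[None, ins_size**2*3], name="x_input")
condition = tf.placeholder(tf.int32, shape=[], name="condition")
W = tf.Variable(tf.zeros([ins_size**2 * 3, label_option]), name="weights")
b = tf.Variable(tf.zeros([label_option]), name="bias")
# tf.cond picks a branch when the graph runs, based on the fed value of `condition`.
y = tf.nn.softmax(tf.cond(condition > 0,
                          lambda: tf.matmul(x, W) + b,
                          lambda: tf.matmul(x, W) - b))
```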
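To check both branches end to end, here is a sketch that runs the `tf.cond` version with a session and `feed_dict`, assuming TensorFlow 2.x, where the 1.x calls live under `tf.compat.v1`. `IN_DIM` and `NUM_LABELS` are hypothetical stand-ins for `ins_size**2*3` and `label_option`, and the bias is set to `[1, 0, 0]` (instead of zeros) only so the two branches give visibly different outputs.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-ins for ins_size**2*3 and label_option.
IN_DIM, NUM_LABELS = 4, 3

graph = tf.Graph()
with graph.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, IN_DIM], name="x_input")
    condition = tf.compat.v1.placeholder(tf.int32, shape=[], name="condition")
    W = tf.compat.v1.Variable(tf.zeros([IN_DIM, NUM_LABELS]), name="weights")
    # Non-zero bias (an assumption) so "+ b" and "- b" differ.
    b = tf.compat.v1.Variable([1.0, 0.0, 0.0], name="bias")
    # The branch is selected at run time from the fed value of `condition`.
    logits = tf.cond(condition > 0,
                     lambda: tf.matmul(x, W) + b,
                     lambda: tf.matmul(x, W) - b)
    y = tf.nn.softmax(logits)
    init = tf.compat.v1.variables_initializer([W, b])

with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init)
    batch = np.ones([2, IN_DIM], dtype=np.float32)
    # Same graph, two different branches, chosen only by the fed value.
    y_pos = sess.run(y, feed_dict={x: batch, condition: 1})
    y_neg = sess.run(y, feed_dict={x: batch, condition: 0})

print(y_pos[0], y_neg[0])
```

With `condition` fed as `1`, the first class gets the highest probability; fed as `0`, it gets the lowest, even though the graph was built only once.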

## Solution 2 - Python

#### TensorFlow 2.0

TF 2.0 introduces a feature called AutoGraph, which, together with `tf.function`, converts Python code into TensorFlow graph code. This means you can use Python control flow statements (yes, this includes `if` statements) with tensor conditions. From the docs:

> AutoGraph supports common Python statements like `while`, `for`, `if`, `break`, `continue` and `return`, with support for nesting. That means you can use Tensor expressions in the condition of `while` and `if` statements, or iterate over a Tensor in a `for` loop.

You will need to define a function implementing your logic and decorate it with `tf.function`. Here is a modified example from the documentation:

```python
import tensorflow as tf

@tf.function
def sum_even(items):
    s = 0
    for c in items:
        if tf.equal(c % 2, 0):
            s += c
    return s

sum_even(tf.constant([10, 12, 15, 20]))
# <tf.Tensor: id=1146, shape=(), dtype=int32, numpy=42>
```
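Applied to the question's code, a minimal sketch (assuming TensorFlow 2.x; `IN_DIM`, `NUM_LABELS`, `predict` and the `[1, 0, 0]` bias are hypothetical) keeps the original `if`/`else` unchanged and lets AutoGraph turn it into a graph conditional, because `condition` is passed in as a tensor:

```python
import tensorflow as tf

# Hypothetical stand-ins for ins_size**2*3 and label_option.
IN_DIM, NUM_LABELS = 4, 3

W = tf.Variable(tf.zeros([IN_DIM, NUM_LABELS]), name="weights")
# Non-zero bias (an assumption) so the two branches give different outputs.
b = tf.Variable([1.0, 0.0, 0.0], name="bias")

@tf.function
def predict(x, condition):
    # `condition > 0` is a Tensor here, so AutoGraph converts this `if`
    # into a graph conditional instead of deciding it once while tracing.
    if condition > 0:
        y = tf.nn.softmax(tf.matmul(x, W) + b)
    else:
        y = tf.nn.softmax(tf.matmul(x, W) - b)
    return y

x = tf.ones([2, IN_DIM])
y_pos = predict(x, tf.constant(1))  # takes the "+ b" branch at run time
y_neg = predict(x, tf.constant(0))  # takes the "- b" branch at run time
print(y_pos.numpy()[0], y_neg.numpy()[0])
```

Both branches must define `y` with the same dtype, since AutoGraph builds them as the two arms of one conditional. If you call `predict(x, 1)` with a plain Python int instead, `condition > 0` is an ordinary Python bool, so the branch is chosen during tracing and each new value triggers a retrace.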