Terribly slow training caused by ComplexBatchNormalization() #30

Open
King-Of-Knights opened this issue Jun 12, 2022 · 4 comments

@King-Of-Knights

Hi there, @NEGU93. Thanks for the great effort in making this library. It has really accelerated my research on a signal recognition task, and the TF 2.0 version helped me deploy to an edge device with the help of TFLite. However, I found that ComplexBatchNormalization() slows down training terribly. Here is an example to reproduce:

import numpy as np
from tensorflow.keras.models import Model
import tensorflow
from cvnn.layers import ComplexConv1D, ComplexInput, ComplexDense, ComplexBatchNormalization, ComplexFlatten, complex_input

X_train = np.random.rand(18000, 4096, 2)
Y_train = np.random.randint(0, 9, 18000)
X_test = np.random.rand(2000, 4096, 2)
Y_test = np.random.randint(0, 9, 2000)

inputs = complex_input(shape=X_train.shape[1:])
outs = inputs
outs = ComplexConv1D(16, 6, strides=1, padding='same', activation='cart_relu')(outs)
outs = ComplexBatchNormalization()(outs)

outs = ComplexConv1D(32, 3, strides=1, padding='same', activation='cart_relu')(outs)
outs = ComplexBatchNormalization()(outs)

outs = ComplexFlatten()(outs)
DL_feature = ComplexDense(128, activation='cart_relu')(outs)
outs = ComplexDense(256, activation='cart_relu')(DL_feature)
outs = ComplexDense(256, activation='cart_relu')(outs)
predictions = ComplexDense(9, activation='cast_to_real')(outs)  # 9 units: labels are drawn from randint(0, 9)

model = Model(inputs=inputs, outputs=predictions)
model.compile(optimizer=tensorflow.keras.optimizers.Adam(learning_rate=1e-4),
              loss=tensorflow.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

history = model.fit(X_train, Y_train, batch_size=32, epochs=3, verbose=1,
                    validation_data=(X_test, Y_test))

It takes almost 10 minutes to train one epoch, but when I substitute BatchNormalization() for ComplexBatchNormalization(), an epoch only costs about half a minute. Any ideas?

@NEGU93
Owner

NEGU93 commented Sep 22, 2022

Indeed, the complex BatchNorm is not optimized, and there are no plans to optimize it in the short term. I am sorry for the trouble caused. The reason is similar to what happens with ComplexPyTorch.

@jollyjonson

I was having the same problem and came up with this simple solution. According to the authors of ComplexPyTorch, performing batch normalization in a 'naive' way, i.e. separately on the real and imaginary parts, does not have a significant impact compared to the complex formulation of Trabelsi et al.

Here's a TF version of their NaiveComplexBatchNorm layer, which can be used with the keras functional API.

import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization

def naive_complex_batch_normalization(inputs: tf.Tensor) -> tf.Tensor:
    # Normalize the real and imaginary parts independently with standard BatchNormalization
    real = tf.cast(tf.math.real(inputs), tf.float32)
    imag = tf.cast(tf.math.imag(inputs), tf.float32)
    real_bn, imag_bn = BatchNormalization()(real), BatchNormalization()(imag)
    # tf.complex(float32, float32) already yields a complex64 tensor
    return tf.cast(tf.complex(real_bn, imag_bn), tf.complex64)
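Used in the model from the first comment, the batch-norm lines would then read (a sketch, untested):

outs = ComplexConv1D(16, 6, strides=1, padding='same', activation='cart_relu')(outs)
outs = naive_complex_batch_normalization(outs)

Note that each call creates a fresh pair of BatchNormalization layers, which is what you want when building a graph with the functional API.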

@NEGU93, would you be interested in a PR implementing this as a proper tf.keras.layers.Layer class?
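A minimal sketch of what such a layer could look like (the class name NaiveComplexBatchNorm is a placeholder until the PR exists):

import tensorflow as tf

class NaiveComplexBatchNorm(tf.keras.layers.Layer):
    """Batch-normalizes the real and imaginary parts independently."""

    def __init__(self, **bn_kwargs):
        super().__init__()
        self.real_bn = tf.keras.layers.BatchNormalization(**bn_kwargs)
        self.imag_bn = tf.keras.layers.BatchNormalization(**bn_kwargs)

    def call(self, inputs, training=None):
        real = self.real_bn(tf.math.real(inputs), training=training)
        imag = self.imag_bn(tf.math.imag(inputs), training=training)
        # Recombine into a single complex tensor
        return tf.complex(real, imag)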

@NEGU93
Owner

NEGU93 commented Oct 5, 2022

Sure. I am not sure what they base that claim on; from my point of view, a naive implementation may have a very negative impact on the phase, which is a crucial aspect of CVNN merits Ref.
But then, using CReLU has a similar component-wise effect, and it still works well, so... why not?
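For context, this is roughly what cart_relu (CReLU) already does; the sketch below illustrates the idea and is not cvnn's exact implementation:

import tensorflow as tf

# CReLU applies the ReLU to the real and imaginary parts independently:
# crelu(z) = relu(Re(z)) + 1j * relu(Im(z))
def cart_relu_sketch(z: tf.Tensor) -> tf.Tensor:
    return tf.complex(tf.nn.relu(tf.math.real(z)), tf.nn.relu(tf.math.imag(z)))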

Please submit your PR, and thank you for the contribution!

@maorgranot1

Here is an example implementation of a small 1D CNN, until that PR is integrated into the cvnn package:

import tensorflow as tf
from tensorflow import keras
import cvnn.layers as layers

def get_model(input_len=1000, activation_func='cart_relu'):  # 'cart_relu' is cvnn's CReLU
    inputs = layers.complex_input(shape=(input_len, 1))
    conv0 = layers.ComplexConv1D(64, 7, activation=activation_func)(inputs)
    # Naive batch norm: normalize the real and imaginary parts separately
    bn_r0 = keras.layers.BatchNormalization()(tf.cast(tf.math.real(conv0), tf.float32))
    bn_i0 = keras.layers.BatchNormalization()(tf.cast(tf.math.imag(conv0), tf.float32))
    p0 = layers.ComplexAvgPooling1D(pool_size=2)(tf.cast(tf.complex(bn_r0, bn_i0), tf.complex64))
    out = layers.ComplexConv1D(32, 3, activation=activation_func)(p0)
    return tf.keras.Model(inputs, out)
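Usage is then simply:

model = get_model()
model.summary()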
