python – How to use K-fold cross validation on transfer learning?

I have created a transfer learning model using Resnet50. I want to perform K-fold cross-validation on my model after which I want to find the average AUC value and standard deviation. However, I am getting an error message while performing the task. I have created a separate Files.csv file which contains the image names and their corresponding labels. I am not sure if this is the correct method or not. Please let me know if there is any other process. Please find my code below:

from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras import Model, layers
from tensorflow.keras.models import load_model, model_from_json
from tensorflow.keras.layers import GlobalAveragePooling2D, Dropout, Dense, Input
import numpy as np
import pandas as pd
import os
from sklearn.model_selection import KFold, StratifiedKFold
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
train_data = pd.read_csv('Files.csv')
Y = train_data[['label']]
kf = KFold(n_splits = 5)
from tensorflow.keras.preprocessing.image import ImageDataGenerator

idg = ImageDataGenerator(rescale = 1./255,
                                  
                                   horizontal_flip=True, 
                                
                                   rotation_range=40,
                                   zoom_range= 0.2,
                                   
                                   shear_range=0.2,
                                   
                                   width_shift_range=0.2,
                                   height_shift_range=0.2,)
validation_datagen = ImageDataGenerator(rescale = 1./255)
def get_model_name(k):
    return 'model_'+str(k)+'.h5'
from keras import models
from keras.layers import Dense, Flatten
from tensorflow.keras import optimizers
from keras.applications.vgg16 import VGG16
from tensorflow.keras.applications import ResNet50
image_dir=r'D:/regionGrowing_MLT/NewSavedRGBImages/Training'
VALIDATION_ACCURACY = []
VALIDAITON_LOSS = []

save_dir="C:/Users/warid"
fold_var = 1

for train_index, val_index in kf.split(np.zeros(n),Y):
        training_data = train_data.iloc[train_index]
        validation_data = train_data.iloc[val_index]

        train_data_generator = idg.flow_from_dataframe(training_data, directory = image_dir,
                                   x_col = "filename", y_col = "label",
                                   class_mode = "categorical", shuffle = True)
        valid_data_generator  = idg.flow_from_dataframe(validation_data, directory = image_dir,
                                x_col = "filename", y_col = "label",
                                class_mode = "categorical", shuffle = True)

        # CREATE NEW MODEL
        model = models.Sequential()
        model.add(ResNet50(weights="imagenet", include_top=False, input_shape=(224,224,3)))
        model.add(Flatten())
        model.add(ChannelAttention(32, 8))
        model.add(SpatialAttention(7))
        model.add(Dense(256, activation='relu', name="fc1"))
        model.add(Dense(128, activation='relu', name="fc2"))
        model.add(layers.Dropout(0.5))  #### used for regularization (to aviod overfitting)
        model.add(Dense(2, activation='sigmoid'))
    # model.summary()

        model.compile(optimizer=optimizers.Adam(learning_rate=2e-5),
                  loss="binary_crossentropy",
                  metrics=['accuracy'])
    # COMPILE NEW MODEL

    # CREATE CALLBACKS
        checkpoint = tf.keras.callbacks.ModelCheckpoint(save_dir+get_model_name(fold_var), 
                                monitor="val_accuracy", verbose=1, 
                                save_best_only=True, mode="max")
        callbacks_list = [checkpoint]
        # There can be other callbacks, but just showing one because it involves the model name
        # This saves the best model
        # FIT THE MODEL
        history = model.fit(train_data_generator,
                    epochs=num_epochs,
                    callbacks=callbacks_list,
                    validation_data=valid_data_generator)
        #PLOT HISTORY
        #       :
        #       :

        # LOAD BEST MODEL to evaluate the performance of the model
        model.load_weights("/saved_models/model_"+str(fold_var)+".h5")

        results = model.evaluate(valid_data_generator)
        results = dict(zip(model.metrics_names,results))

        VALIDATION_ACCURACY.append(results['accuracy'])
        VALIDATION_LOSS.append(results['loss'])

        tf.keras.backend.clear_session()

        fold_var += 1

After running this code, I am getting the following error message:

Found 3076 validated image filenames belonging to 2 classes.
Found 769 validated image filenames belonging to 1 classes.
Epoch 1/5
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
Input In [75], in <cell line: 7>()
     39 callbacks_list = [checkpoint]
     40 # There can be other callbacks, but just showing one because it involves the model name
     41 # This saves the best model
     42 # FIT THE MODEL
---> 43 history = model.fit(train_data_generator,
     44             epochs=num_epochs,
     45             callbacks=callbacks_list,
     46             validation_data=valid_data_generator)
     47 #PLOT HISTORY
     48 #       :
     49 #       :
     50 
     51 # LOAD BEST MODEL to evaluate the performance of the model
     52 model.load_weights("/saved_models/model_"+str(fold_var)+".h5")

File ~anaconda3libsite-packageskerasutilstraceback_utils.py:67, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     65 except Exception as e:  # pylint: disable=broad-except
     66   filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67   raise e.with_traceback(filtered_tb) from None
     68 finally:
     69   del filtered_tb

File ~anaconda3libsite-packagestensorflowpythoneagerexecute.py:54, in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     52 try:
     53   ctx.ensure_initialized()
---> 54   tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     55                                       inputs, attrs, num_outputs)
     56 except core._NotOkStatusException as e:
     57   if name is not None:

InvalidArgumentError: Graph execution error:
Detected at node 'sequential_2/flatten_2/Reshape' defined at (most recent call last):
    File "C:Userswaridanaconda3librunpy.py", line 197, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "C:Userswaridanaconda3librunpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "C:Userswaridanaconda3libsite-packagesipykernel_launcher.py", line 16, in <module>
      app.launch_new_instance()
    File "C:Userswaridanaconda3libsite-packagestraitletsconfigapplication.py", line 846, in launch_instance
      app.start()
    File "C:Userswaridanaconda3libsite-packagesipykernelkernelapp.py", line 677, in start
      self.io_loop.start()
    File "C:Userswaridanaconda3libsite-packagestornadoplatformasyncio.py", line 199, in start
      self.asyncio_loop.run_forever()
    File "C:Userswaridanaconda3libasynciobase_events.py", line 601, in run_forever
      self._run_once()
    File "C:Userswaridanaconda3libasynciobase_events.py", line 1905, in _run_once
      handle._run()
    File "C:Userswaridanaconda3libasyncioevents.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "C:Userswaridanaconda3libsite-packagesipykernelkernelbase.py", line 471, in dispatch_queue
      await self.process_one()
    File "C:Userswaridanaconda3libsite-packagesipykernelkernelbase.py", line 460, in process_one
      await dispatch(*args)
    File "C:Userswaridanaconda3libsite-packagesipykernelkernelbase.py", line 367, in dispatch_shell
      await result
    File "C:Userswaridanaconda3libsite-packagesipykernelkernelbase.py", line 662, in execute_request
      reply_content = await reply_content
    File "C:Userswaridanaconda3libsite-packagesipykernelipkernel.py", line 360, in do_execute
      res = shell.run_cell(code, store_history=store_history, silent=silent)
    File "C:Userswaridanaconda3libsite-packagesipykernelzmqshell.py", line 532, in run_cell
      return super().run_cell(*args, **kwargs)
    File "C:Userswaridanaconda3libsite-packagesIPythoncoreinteractiveshell.py", line 2863, in run_cell
      result = self._run_cell(
    File "C:Userswaridanaconda3libsite-packagesIPythoncoreinteractiveshell.py", line 2909, in _run_cell
      return runner(coro)
    File "C:Userswaridanaconda3libsite-packagesIPythoncoreasync_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "C:Userswaridanaconda3libsite-packagesIPythoncoreinteractiveshell.py", line 3106, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "C:Userswaridanaconda3libsite-packagesIPythoncoreinteractiveshell.py", line 3309, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "C:Userswaridanaconda3libsite-packagesIPythoncoreinteractiveshell.py", line 3369, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "C:UserswaridAppDataLocalTempipykernel_370762928028949.py", line 43, in <cell line: 7>
      history = model.fit(train_data_generator,
    File "C:Userswaridanaconda3libsite-packageskerasutilstraceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "C:Userswaridanaconda3libsite-packageskerasenginetraining.py", line 1409, in fit
      tmp_logs = self.train_function(iterator)
    File "C:Userswaridanaconda3libsite-packageskerasenginetraining.py", line 1051, in train_function
      return step_function(self, iterator)
    File "C:Userswaridanaconda3libsite-packageskerasenginetraining.py", line 1040, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:Userswaridanaconda3libsite-packageskerasenginetraining.py", line 1030, in run_step
      outputs = model.train_step(data)
    File "C:Userswaridanaconda3libsite-packageskerasenginetraining.py", line 889, in train_step
      y_pred = self(x, training=True)
    File "C:Userswaridanaconda3libsite-packageskerasutilstraceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "C:Userswaridanaconda3libsite-packageskerasenginetraining.py", line 490, in __call__
      return super().__call__(*args, **kwargs)
    File "C:Userswaridanaconda3libsite-packageskerasutilstraceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "C:Userswaridanaconda3libsite-packageskerasenginebase_layer.py", line 1014, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:Userswaridanaconda3libsite-packageskerasutilstraceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "C:Userswaridanaconda3libsite-packageskerasenginesequential.py", line 374, in call
      return super(Sequential, self).call(inputs, training=training, mask=mask)
    File "C:Userswaridanaconda3libsite-packageskerasenginefunctional.py", line 458, in call
      return self._run_internal_graph(
    File "C:Userswaridanaconda3libsite-packageskerasenginefunctional.py", line 596, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "C:Userswaridanaconda3libsite-packageskerasutilstraceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "C:Userswaridanaconda3libsite-packageskerasenginebase_layer.py", line 1014, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:Userswaridanaconda3libsite-packageskerasutilstraceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "C:Userswaridanaconda3libsite-packageskeraslayersreshapingflatten.py", line 98, in call
      return tf.reshape(inputs, flattened_shape)
Node: 'sequential_2/flatten_2/Reshape'
Input to reshape is a tensor with 4194304 values, but the requested shape requires a multiple of 100352
     [[{{node sequential_2/flatten_2/Reshape}}]] [Op:__inference_train_function_37893]

Leave a Comment