python – Mask R-CNN is not loading weights properly for inference and re-training

QUESTION:

I’m new to the world of computer vision, and this is my second project in it. I am running an edited version of the Matterport Mask R-CNN that works with tensorflow-gpu==2.7 (I found out later that it would have worked fine with an older version). I am trying to use it with a pen dataset I created.

Anyway, the problem I am having is that whenever I load the trained weights into the model to resume training, all the metrics shoot back up. I also get bad predictions when I load them for inference. Why are my weights not loading or saving properly? I am saving the weights using callbacks and loading them with the following:

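import mrcnn.model as modellib

# Recreate the model in inference mode; inference_config and MODEL_DIR
# are defined elsewhere in my code
model = modellib.MaskRCNN(mode="inference",
                          config=inference_config,
                          model_dir=MODEL_DIR)

# Get path to the most recent saved weights under MODEL_DIR
model_path = model.find_last()

# Load trained weights
print("Loading weights from ", model_path)
model.load_weights(model_path, by_name=True)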
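To rule out that a different file is being loaded than intended, here is a minimal stdlib-only sketch of how, as far as I can tell from the upstream Matterport source, `find_last()` picks a checkpoint: it takes the newest subdirectory of `model_dir` whose name starts with the lowercased config `NAME`, then the last `mask_rcnn*` file in it when sorted by name. `find_last_checkpoint` is a hypothetical helper written for illustration, and an edited TF2 fork may behave differently.

```python
import os


def find_last_checkpoint(model_dir, config_name):
    """Hypothetical helper mimicking upstream MaskRCNN.find_last()
    (my reading of the source, not the library function itself)."""
    key = config_name.lower()
    # Immediate subdirectories of model_dir, e.g. "pen20211220T1530"
    run_dirs = sorted(d for d in next(os.walk(model_dir))[1] if d.startswith(key))
    if not run_dirs:
        raise FileNotFoundError(f"No '{key}*' run directories in {model_dir}")
    # The timestamp suffix sorts chronologically, so the newest run wins
    run_dir = os.path.join(model_dir, run_dirs[-1])
    checkpoints = sorted(
        f for f in next(os.walk(run_dir))[2] if f.startswith("mask_rcnn")
    )
    if not checkpoints:
        # A newer run with no saved epochs ends up here, even if
        # older run directories still contain good weights
        raise FileNotFoundError(f"No mask_rcnn* checkpoints in {run_dir}")
    # Zero-padded epoch numbers make the last name the latest epoch
    return os.path.join(run_dir, checkpoints[-1])
```

Comparing `find_last_checkpoint(MODEL_DIR, inference_config.NAME)` with the path printed above shows which run is actually being used. Because the newest run directory always wins, a short or interrupted run started later shadows older checkpoints.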

WHAT I’VE TRIED:

I have tried saving the whole model by setting save_weights_only to False in the callbacks. I ran into the get_config() issue described in this thread and tried some of the solutions there, but to no avail.

I have also tried changing the image sizes and the number of epochs.

I have tried saving the model using:

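from tensorflow import keras

# model is the MaskRCNN wrapper; keras_model is the underlying Keras model
model.keras_model.save("path/to/location")
# Note: load_model returns a plain Keras model, not a MaskRCNN wrapper
model = keras.models.load_model("path/to/location")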

which led to the same get_config() issue.
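For reference, here is a sketch of the checkpoint naming scheme used by the upstream Matterport `train()` and `set_log_dir()` code, assuming the edited fork kept it: the checkpoint callback (upstream uses `ModelCheckpoint(..., save_weights_only=True)`) writes one `mask_rcnn_<name>_<epoch>.h5` file per epoch inside a `<name><YYYYmmddTHHMM>` run directory, and `load_weights()` parses the epoch back out of that path so training continues from it. `checkpoint_path` and `parse_checkpoint` are hypothetical helpers written for illustration.

```python
import datetime
import os
import re

# Assumption: mirrors the pattern in upstream MaskRCNN.set_log_dir();
# an edited TF2 fork may use a different one.
CHECKPOINT_RE = re.compile(
    r".*[/\\][\w-]+(\d{4})(\d{2})(\d{2})T(\d{2})(\d{2})[/\\]"
    r"mask_rcnn_[\w-]+(\d{4})\.h5$"
)


def checkpoint_path(model_dir, config_name, started, epoch):
    """Build the path the checkpoint callback would write for a 1-based epoch."""
    name = config_name.lower()
    run_dir = os.path.join(model_dir, "{}{:%Y%m%dT%H%M}".format(name, started))
    return os.path.join(run_dir, "mask_rcnn_{}_{:04d}.h5".format(name, epoch))


def parse_checkpoint(path):
    """Return (run start time, epoch number in the file name), or None
    if the path does not follow the naming scheme."""
    m = CHECKPOINT_RE.match(path)
    if m is None:
        return None
    year, month, day, hour, minute, epoch = (int(g) for g in m.groups())
    return datetime.datetime(year, month, day, hour, minute), epoch
```

In the upstream code, the parsed epoch becomes `initial_epoch` for the next `train()` call, so the `epochs` argument is the epoch to stop at, not the number of extra epochs. If a checkpoint is renamed so the pattern no longer matches, upstream falls back to epoch 0 and a new run directory.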

RESOURCES:

Here is a list of the things I am running:

 # ITEM ########### VERSION ##########################
 # Python         # 3.9.7                            #
 # conda          # 4.10.3                           # 
 # CUDA           # 11.4                             #
 # WindowsOS      # 11                               #
 # cuDNN          # 8.2.4                            # 
 #####################################################
 
################################### PACKAGES ##################################
# packages in environment at C:\Users\ecsan\anaconda3\envs\Prototype:
# Command: conda list
# Name #################### Version ################ Build # Channel ############
# absl-py                   1.0.0                    pypi_0    pypi           #
# alabaster                 0.7.12                   pypi_0    pypi           #
# argon2-cffi               21.1.0                   pypi_0    pypi           #
# astunparse                1.6.3                    pypi_0    pypi           #
# attrs                     21.2.0                   pypi_0    pypi           #
# babel                     2.9.1                    pypi_0    pypi           #
# backcall                  0.2.0                    pypi_0    pypi           #
# bleach                    4.1.0                    pypi_0    pypi           #
# ca-certificates           2021.10.8            h5b45459_0    conda-forge    # 
# cachetools                4.2.4                    pypi_0    pypi           #
# certifi                   2021.10.8                pypi_0    pypi           #
# cffi                      1.15.0                   pypi_0    pypi           #
# charset-normalizer        2.0.9                    pypi_0    pypi           #
# colorama                  0.4.4                    pypi_0    pypi           #
# console_shortcut          0.1.1                         4                   #
# cycler                    0.11.0                   pypi_0    pypi           #
# cython                    0.29.25                  pypi_0    pypi           #
# debugpy                   1.5.1                    pypi_0    pypi           #
# decorator                 5.1.0                    pypi_0    pypi           #
# defusedxml                0.7.1                    pypi_0    pypi           #
# dill                      0.3.4                    pypi_0    pypi           #
# docutils                  0.17.1                   pypi_0    pypi           #
# entrypoints               0.3                      pypi_0    pypi           #
# flatbuffers               2.0                      pypi_0    pypi           #
# fonttools                 4.28.3                   pypi_0    pypi           #
# gast                      0.4.0                    pypi_0    pypi           #
# google-auth               2.3.3                    pypi_0    pypi           #
# google-auth-oauthlib      0.4.6                    pypi_0    pypi           #
# google-pasta              0.2.0                    pypi_0    pypi           #
# grpcio                    1.42.0                   pypi_0    pypi           #
# h5py                      3.6.0                    pypi_0    pypi           #
# idna                      3.3                      pypi_0    pypi           #
# imageio                   2.13.2                   pypi_0    pypi           #
# imagesize                 1.3.0                    pypi_0    pypi           #
# imgaug                    0.4.0                    pypi_0    pypi           #
# importlib-metadata        4.8.2                    pypi_0    pypi           #
# ipykernel                 6.6.0                    pypi_0    pypi           #
# ipyparallel               8.0.0                    pypi_0    pypi           #
# ipython                   7.30.1                   pypi_0    pypi           #
# ipython-genutils          0.2.0                    pypi_0    pypi           #
# ipywidgets                7.6.5                    pypi_0    pypi           #
# jedi                      0.18.1                   pypi_0    pypi           #
# jinja2                    3.0.3                    pypi_0    pypi           #
# joblib                    1.1.0                    pypi_0    pypi           #
# jsonschema                4.2.1                    pypi_0    pypi           #
# jupyter-client            7.1.0                    pypi_0    pypi           #
# jupyter-core              4.9.1                    pypi_0    pypi           #
# jupyterlab-pygments       0.1.2                    pypi_0    pypi           #
# jupyterlab-widgets        1.0.2                    pypi_0    pypi           #
# keras                     2.7.0                    pypi_0    pypi           #
# keras-preprocessing       1.1.2                    pypi_0    pypi           #
# kiwisolver                1.3.2                    pypi_0    pypi           #
# libclang                  12.0.0                   pypi_0    pypi           #
# markdown                  3.3.6                    pypi_0    pypi           #
# markupsafe                2.0.1                    pypi_0    pypi           #
# matplotlib                3.5.0                    pypi_0    pypi           #
# matplotlib-inline         0.1.3                    pypi_0    pypi           #
# mistune                   0.8.4                    pypi_0    pypi           #
# nbclient                  0.5.9                    pypi_0    pypi           #
# nbconvert                 6.3.0                    pypi_0    pypi           #
# nbformat                  5.1.3                    pypi_0    pypi           #
# nest-asyncio              1.5.4                    pypi_0    pypi           #
# networkx                  2.6.3                    pypi_0    pypi           #
# nose                      1.3.7                    pypi_0    pypi           #
# notebook                  6.4.6                    pypi_0    pypi           #
# numpy                     1.19.5                   pypi_0    pypi           #
# oauthlib                  3.1.1                    pypi_0    pypi           #
# opencv-python             4.5.4.60                 pypi_0    pypi           #
# openssl                   3.0.0                h8ffe710_2    conda-forge    #
# opt-einsum                3.3.0                    pypi_0    pypi           #
# packaging                 21.3                     pypi_0    pypi           #
# pandocfilters             1.5.0                    pypi_0    pypi           #
# parso                     0.8.3                    pypi_0    pypi           #
# pickleshare               0.7.5                    pypi_0    pypi           #
# pillow                    8.4.0                    pypi_0    pypi           #
# pip                       21.3.1             pyhd8ed1ab_0    conda-forge    #
# prometheus-client         0.12.0                   pypi_0    pypi           #
# prompt-toolkit            3.0.23                   pypi_0    pypi           #
# protobuf                  3.19.1                   pypi_0    pypi           #
# psutil                    5.8.0                    pypi_0    pypi           #
# pyasn1                    0.4.8                    pypi_0    pypi           #
# pyasn1-modules            0.2.8                    pypi_0    pypi           #
# pycparser                 2.21                     pypi_0    pypi           #
# pygments                  2.10.0                   pypi_0    pypi           #
# pyparsing                 3.0.6                    pypi_0    pypi           #
# pyrsistent                0.18.0                   pypi_0    pypi           #
# python                    3.9.7        h900ac77_3_cpython    conda-forge    #
# python-dateutil           2.8.2                    pypi_0    pypi           #
# python_abi                3.9                      2_cp39    conda-forge    #
# pytz                      2021.3                   pypi_0    pypi           #
# pywavelets                1.2.0                    pypi_0    pypi           #
# pywin32                   302                      pypi_0    pypi           #
# pywinpty                  1.1.6                    pypi_0    pypi           #
# pyzmq                     22.3.0                   pypi_0    pypi           #
# qtconsole                 5.2.1                    pypi_0    pypi           #
# qtpy                      1.11.3                   pypi_0    pypi           #
# requests                  2.26.0                   pypi_0    pypi           #
# requests-oauthlib         1.3.0                    pypi_0    pypi           #
# rsa                       4.8                      pypi_0    pypi           #
# scikit-image              0.18.3                   pypi_0    pypi           #
# scipy                     1.7.3                    pypi_0    pypi           #
# send2trash                1.8.0                    pypi_0    pypi           #
# setuptools                59.4.0           py39hcbf5309_0    conda-forge    #
# setuptools-scm            6.3.2                    pypi_0    pypi           #
# shapely                   1.8.0                    pypi_0    pypi           #
# six                       1.15.0                   pypi_0    pypi           #
# snowballstemmer           2.2.0                    pypi_0    pypi           #
# sphinx                    4.3.1                    pypi_0    pypi           #
# sphinxcontrib-applehelp   1.0.2                    pypi_0    pypi           #
# sphinxcontrib-devhelp     1.0.2                    pypi_0    pypi           #
# sphinxcontrib-htmlhelp    2.0.0                    pypi_0    pypi           #
# sphinxcontrib-jsmath      1.0.1                    pypi_0    pypi           #
# sphinxcontrib-qthelp      1.0.3                    pypi_0    pypi           #
# sphinxcontrib-serializinghtml 1.1.5                pypi_0    pypi           #
# sqlite                    3.37.0               h8ffe710_0    conda-forge    #
# tb-nightly                2.8.0a20211220           pypi_0    pypi           #
# tensorboard               2.7.0                    pypi_0    pypi           #
# tensorboard-data-server   0.6.1                    pypi_0    pypi           #
# tensorboard-plugin-wit    1.8.0                    pypi_0    pypi           #
# tensorflow-estimator      2.7.0                    pypi_0    pypi           #
# tensorflow-gpu            2.7.0                    pypi_0    pypi           #
# tensorflow-io-gcs-filesystem 0.23.1                pypi_0    pypi           #
# termcolor                 1.1.0                    pypi_0    pypi           #
# terminado                 0.12.1                   pypi_0    pypi           #
# testpath                  0.5.0                    pypi_0    pypi           #
# tf-estimator-nightly      2.8.0.dev2021122009      pypi_0    pypi           #
# tifffile                  2021.11.2                pypi_0    pypi           #
# tomli                     1.2.2                    pypi_0    pypi           #
# tornado                   6.1                      pypi_0    pypi           #
# tqdm                      4.62.3                   pypi_0    pypi           #
# traitlets                 5.1.1                    pypi_0    pypi           #
# typing-extensions         4.0.1                    pypi_0    pypi           #
# tzdata                    2021e                he74cb21_0    conda-forge    #
# ucrt                      10.0.20348.0         h57928b3_0    conda-forge    #
# urllib3                   1.26.7                   pypi_0    pypi           #
# vc                        14.2                 hb210afc_5    conda-forge    #
# vs2015_runtime            14.29.30037          h902a5da_5    conda-forge    #
# wcwidth                   0.2.5                    pypi_0    pypi           #
# webencodings              0.5.1                    pypi_0    pypi           #
# werkzeug                  2.0.2                    pypi_0    pypi           #
# wheel                     0.37.0             pyhd8ed1ab_1    conda-forge    #
# widgetsnbextension        3.5.2                    pypi_0    pypi           #
# wrapt                     1.13.3                   pypi_0    pypi           #
# zipp                      3.6.0                    pypi_0    pypi           #
###############################################################################

Here is a link to my TensorBoard and an example of a bad prediction:

You should see the model learning and then a spike at the end; that spike is where I loaded the weights again and resumed training.

https://tensorboard.dev/experiment/KkgugOP7RGu12lVCA6M29Q/

(Image: example of a bad prediction)

Here is my custom config for training:

from mrcnn.config import Config


class CustomConfig(Config):
    """Configuration for training on the pen dataset.
    Derives from the base Config class and overrides some values.
    """

    DETECTION_MIN_CONFIDENCE = 0.7  # Skip detections with < 70% confidence
    # Give the configuration a recognizable name
    NAME = "PEN"

    # Train on 1 GPU and 8 images per GPU. We can put multiple images on each
    # GPU because the images are small. Batch size is 8 (GPUs * images/GPU).
    GPU_COUNT = 1
    IMAGES_PER_GPU = 8

    # Number of classes (including background)
    NUM_CLASSES = 1 + 1  # background + PEN

    # Use small images for faster training. Set the limits of the small side
    # and the large side, and that determines the image shape.
    IMAGE_MIN_DIM = 128
    IMAGE_MAX_DIM = 128

    # Use smaller anchors because our image and objects are small
    RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128)  # anchor side in pixels

    # Reduce training ROIs per image because the images are small and have
    # few objects. Aim to allow ROI sampling to pick 33% positive ROIs.
    TRAIN_ROIS_PER_IMAGE = 32

    # Use a small epoch since the data is simple
    STEPS_PER_EPOCH = 300

    # Use small validation steps since the epoch is small
    VALIDATION_STEPS = 10


config = CustomConfig()
config.display()
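The arithmetic behind the comments in this config can be checked without TensorFlow. The sketch below recomputes a few derived values for the posted settings, assuming the upstream Matterport defaults `IMAGE_RESIZE_MODE = "square"`, `IMAGE_CHANNEL_COUNT = 3` and `BACKBONE_STRIDES = [4, 8, 16, 32, 64]` (an edited fork could change them); `derived_values` is a hypothetical helper, not part of `mrcnn`.

```python
import math

# Assumed upstream Matterport defaults for settings the config does not override
BACKBONE_STRIDES = [4, 8, 16, 32, 64]
IMAGE_CHANNEL_COUNT = 3


def derived_values(gpu_count, images_per_gpu, image_max_dim, anchor_scales):
    """Recompute a few values Config derives, for IMAGE_RESIZE_MODE="square"."""
    batch_size = gpu_count * images_per_gpu
    # In "square" mode every image is padded to IMAGE_MAX_DIM x IMAGE_MAX_DIM
    image_shape = [image_max_dim, image_max_dim, IMAGE_CHANNEL_COUNT]
    # The network downsamples 6 times, so the side must be a multiple of 2**6
    if image_max_dim % 2 ** 6 != 0:
        raise ValueError(f"IMAGE_MAX_DIM={image_max_dim} is not a multiple of 64")
    # One feature map per pyramid level (P2..P6), paired with one anchor scale
    feature_maps = [math.ceil(image_max_dim / s) for s in BACKBONE_STRIDES]
    if len(anchor_scales) != len(feature_maps):
        raise ValueError("RPN_ANCHOR_SCALES needs one entry per pyramid level")
    return batch_size, image_shape, dict(zip(anchor_scales, feature_maps))
```

For the training config this gives a batch size of 8, an image shape of [128, 128, 3], and feature maps of 32, 16, 8, 4 and 2 cells per side for the anchor scales 8 through 128.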

Here is my inference config:

class InferenceConfig(CustomConfig):
    NAME = "PEN"

    NUM_CLASSES = 1 + 1  # background + PEN

    # Keep the same image size as in training. Set the limits of the small
    # side and the large side, and that determines the image shape.
    IMAGE_MIN_DIM = 128
    IMAGE_MAX_DIM = 128

    # Use smaller anchors because our image and objects are small
    RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128)  # anchor side in pixels

    # Run detection on one image at a time
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1
    DETECTION_MIN_CONFIDENCE = 0.9  # Skip detections with < 90% confidence


inference_config = InferenceConfig()
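Since `InferenceConfig` inherits from `CustomConfig`, one quick way to confirm that only the intended settings differ between the two is to diff their attributes. This is a minimal sketch using the same filter that Matterport's `Config.display()` applies (public, non-callable attributes); `config_values` and `config_diff` are hypothetical helpers.

```python
def config_values(cfg):
    """Public, non-callable attributes, filtered the way Config.display() does."""
    values = {}
    for name in dir(cfg):
        if name.startswith("__"):
            continue
        value = getattr(cfg, name)
        if not callable(value):
            values[name] = value
    return values


def config_diff(train_cfg, infer_cfg):
    """Return {name: (training value, inference value)} for settings that differ."""
    a, b = config_values(train_cfg), config_values(infer_cfg)
    diff = {}
    for name in sorted(a.keys() | b.keys()):
        # Compare repr() strings so NumPy-array settings such as IMAGE_SHAPE
        # do not raise "truth value of an array is ambiguous" on ==
        if repr(a.get(name)) != repr(b.get(name)):
            diff[name] = (a.get(name), b.get(name))
    return diff
```

Running `print(config_diff(config, inference_config))` on the two configs above should list only IMAGES_PER_GPU, DETECTION_MIN_CONFIDENCE and the derived BATCH_SIZE, assuming the fork derives BATCH_SIZE the same way as upstream.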

If you need additional information, please let me know. This is also my first post, so any guidance is appreciated.
