diff --git a/tensorflow_addons/utils/ensure_tf_install.py b/tensorflow_addons/utils/ensure_tf_install.py index 4351afbc7a..e5dcae52fe 100644 --- a/tensorflow_addons/utils/ensure_tf_install.py +++ b/tensorflow_addons/utils/ensure_tf_install.py @@ -17,27 +17,40 @@ # Ensure TensorFlow is importable and its version is sufficiently recent. This # needs to happen before anything else, since the imports below will try to # import tensorflow, too. + +from distutils.version import LooseVersion +import warnings + +import tensorflow as tf + + +warning_template = """ +This version of TensorFlow Addons requires TensorFlow {required}. +Detected an installation of version {present}. + +While some functions might work, TensorFlow Addons was not tested +with this TensorFlow version. Also custom ops were not compiled +against this version of TensorFlow. If you use custom ops, +you might get errors (segmentation faults for example). + +It might help you to fallback to pure Python ops with +TF_ADDONS_PY_OPS . To do that, see +https://github.com/tensorflow/addons#gpucpu-custom-ops + +If you encounter errors, do *not* file bugs in GitHub because +the version of TensorFlow you are using is not supported. +""" + + def _ensure_tf_install(): - """Attempt to import tensorflow, and ensure its version is sufficient. - Raises: - ImportError: if either tensorflow is not importable or its version is - inadequate. + """Warn the user if the version of TensorFlow used is not supported. """ - import tensorflow as tf - import distutils.version - # # Update this whenever we need to depend on a newer TensorFlow release. - # - required_tensorflow_version = "2.1.0" - - if distutils.version.LooseVersion(tf.__version__) < distutils.version.LooseVersion( - required_tensorflow_version - ): - raise ImportError( - "This version of TensorFlow Addons requires TensorFlow " - "version >= {required}; Detected an installation of version " - "{present}. Please upgrade TensorFlow to proceed.".format( - required=required_tensorflow_version, present=tf.__version__ - ) + required_tf_version = "2.1.0" + + if LooseVersion(tf.__version__) != LooseVersion(required_tf_version): + message = warning_template.format( + required=required_tf_version, present=tf.__version__ ) + warnings.warn(message, UserWarning)