Blob Blame History Raw
From e6928c405fe1647a0ee6ea32543b18ebc3828227 Mon Sep 17 00:00:00 2001
From: Mattias Ellert <mattias.ellert@physics.uu.se>
Date: Thu, 23 Feb 2023 22:10:46 +0100
Subject: [PATCH] Use consistent wording of the comments in the different TMVA
 classification tests to avoid confusion.

---
 tutorials/tmva/TMVA_CNN_Classification.C  | 2 +-
 tutorials/tmva/TMVA_CNN_Classification.py | 4 ++--
 tutorials/tmva/TMVA_RNN_Classification.C  | 6 ++++--
 tutorials/tmva/TMVA_RNN_Classification.py | 5 +++--
 4 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/tutorials/tmva/TMVA_CNN_Classification.C b/tutorials/tmva/TMVA_CNN_Classification.C
index cc6829a24d..db5af784d4 100644
--- a/tutorials/tmva/TMVA_CNN_Classification.C
+++ b/tutorials/tmva/TMVA_CNN_Classification.C
@@ -134,7 +134,7 @@ void TMVA_CNN_Classification(int nevts = 1000, std::vector<bool> opt = {1, 1, 1,
 
    bool writeOutputFile = true;
 
-   int num_threads = 4;  // use by default 4 threads if value is not set before
+   int num_threads = 4; // use max 4 threads
    // switch off MT in OpenBLAS to avoid conflict with tbb
    gSystem->Setenv("OMP_NUM_THREADS", "1");
 
diff --git a/tutorials/tmva/TMVA_CNN_Classification.py b/tutorials/tmva/TMVA_CNN_Classification.py
index 85fde20763..5d3c6e7220 100644
--- a/tutorials/tmva/TMVA_CNN_Classification.py
+++ b/tutorials/tmva/TMVA_CNN_Classification.py
@@ -25,7 +25,7 @@
 
 import ROOT
 
-#switch off MT in OpenMP (BLAS)
+# switch off MT in OpenBLAS to avoid conflict with tbb
 ROOT.gSystem.Setenv("OMP_NUM_THREADS", "1")
 
 TMVA = ROOT.TMVA
@@ -145,7 +145,7 @@ if not useTMVACNN:
 
 writeOutputFile = True
 
-num_threads = 4  # use default threads
+num_threads = 4  # use max 4 threads
 max_epochs = 10  # maximum number of epochs used for training
 
 
diff --git a/tutorials/tmva/TMVA_RNN_Classification.C b/tutorials/tmva/TMVA_RNN_Classification.C
index 8701ba6c19..d4bb7a4e4a 100644
--- a/tutorials/tmva/TMVA_RNN_Classification.C
+++ b/tutorials/tmva/TMVA_RNN_Classification.C
@@ -191,8 +191,10 @@ void TMVA_RNN_Classification(int nevts = 2000, int use_type = 1)
    useKeras = false;
 #endif
 
-   int num_threads = 4;   // use by default all threads
-   gSystem->Setenv("OMP_NUM_THREADS", "1"); // switch off MT in OpenBLAS
+   int num_threads = 4; // use max 4 threads
+   // switch off MT in OpenBLAS to avoid conflict with tbb
+   gSystem->Setenv("OMP_NUM_THREADS", "1");
+
    // do enable MT running
    if (num_threads >= 0) {
       ROOT::EnableImplicitMT(num_threads);
diff --git a/tutorials/tmva/TMVA_RNN_Classification.py b/tutorials/tmva/TMVA_RNN_Classification.py
index 625a120b3d..65aa83c070 100644
--- a/tutorials/tmva/TMVA_RNN_Classification.py
+++ b/tutorials/tmva/TMVA_RNN_Classification.py
@@ -26,10 +26,11 @@ num_threads = 4  # use max 4 threads
 # do enable MT running
 if ROOT.gSystem.GetFromPipe("root-config --has-imt") == "yes":
     ROOT.EnableImplicitMT(num_threads)
-    ROOT.gSystem.Setenv("OMP_NUM_THREADS", "1")  # switch OFF MT in OpenBLAS
+    # switch off MT in OpenBLAS to avoid conflict with tbb
+    ROOT.gSystem.Setenv("OMP_NUM_THREADS", "1")
     print("Running with nthreads  = {}".format(ROOT.GetThreadPoolSize()))
 else:
-    print("Running in serail mode since ROOT does not support MT")
+    print("Running in serial mode since ROOT does not support MT")
 
 
 TMVA = ROOT.TMVA
-- 
2.39.2