@@ -154,19 +154,43 @@ class MATHS_COMMON_EXPORT CNaiveBayesFeatureDensityFromPrior final
154
154
TPriorPtr m_Prior;
155
155
};
156
156
157
+ // ! \brief Enables using custom feature weights in class prediction.
// !
// ! Pure virtual interface: add receives a class index together with a
// ! log-likelihood and calculate reports a weight as a double. The
// ! CNaiveBayes prediction methods document that the weight should be in
// ! the range [0,1], with smaller values giving a feature less impact on
// ! class selection.
158
+ class CNaiveBayesFeatureWeight {
159
+ public:
160
+ virtual ~CNaiveBayesFeatureWeight () = default ;
// ! Supply the log-likelihood \p logLikelihood for the class \p class_.
161
+ virtual void add (std::size_t class_, double logLikelihood) = 0;
// ! Get the feature weight.
162
+ virtual double calculate () const = 0;
163
+ };
164
+
157
165
// ! \brief Implements a Naive Bayes classifier.
158
166
class MATHS_COMMON_EXPORT CNaiveBayes {
159
167
public:
168
+ using TDoubleDoublePr = std::pair<double , double >;
160
169
using TDoubleSizePr = std::pair<double , std::size_t >;
161
170
using TDoubleSizePrVec = std::vector<TDoubleSizePr>;
171
+ using TDoubleSizePrVecDoublePr = std::pair<TDoubleSizePrVec, double >;
162
172
using TDouble1Vec = core::CSmallVector<double , 1 >;
163
173
using TDouble1VecVec = std::vector<TDouble1Vec>;
164
- using TOptionalDouble = std::optional<double >;
174
+ using TFeatureWeightProvider = std::function<CNaiveBayesFeatureWeight&()>;
175
+
176
+ private:
177
+ // ! \brief All features have unit weight in class prediction.
// !
// ! Implementation of CNaiveBayesFeatureWeight which ignores everything
// ! it is given and always reports a weight of 1.0.
178
+ class CUnitFeatureWeight : public CNaiveBayesFeatureWeight {
179
+ public:
// ! No-op: both arguments are unnamed and unused.
180
+ void add (std::size_t , double ) override {}
// ! Always returns 1.0.
181
+ double calculate () const override { return 1.0 ; }
182
+ };
183
+
184
// ! \brief Callable which hands out a reference to a CUnitFeatureWeight
// ! it owns.
// !
// ! A default constructed instance is the default value of the
// ! weightProvider argument of the class probability methods.
+ class CUnitFeatureWeightProvider {
185
+ public:
// ! Get the owned unit weight object.
186
+ CUnitFeatureWeight& operator ()() const { return m_UnitWeight; }
187
+
188
+ private:
// ! Mutable so the const call operator can return a non-const reference.
189
+ mutable CUnitFeatureWeight m_UnitWeight;
190
+ };
165
191
166
192
public:
167
- explicit CNaiveBayes (const CNaiveBayesFeatureDensity& exemplar,
168
- double decayRate = 0.0 ,
169
- TOptionalDouble minMaxLogLikelihoodToUseFeature = TOptionalDouble());
193
+ explicit CNaiveBayes (const CNaiveBayesFeatureDensity& exemplar, double decayRate = 0.0 );
170
194
CNaiveBayes (const CNaiveBayesFeatureDensity& exemplar,
171
195
const SDistributionRestoreParams& params,
172
196
core::CStateRestoreTraverser& traverser);
@@ -184,6 +208,9 @@ class MATHS_COMMON_EXPORT CNaiveBayes {
184
208
// ! Check if any training data has been added.
185
209
bool initialized () const ;
186
210
211
+ // ! Get the number of classes.
212
+ std::size_t numberClasses () const ;
213
+
187
214
// ! This can be used to optionally seed the class counts
188
215
// ! with \p counts. These are added on to data class counts
189
216
// ! to compute the class posterior probabilities.
@@ -210,27 +237,53 @@ class MATHS_COMMON_EXPORT CNaiveBayes {
210
237
// !
211
238
// ! \param[in] n The number of class probabilities to estimate.
212
239
// ! \param[in] x The feature values.
240
+ // ! \param[in] weightProvider Computes a feature weight from the class
241
+ // ! conditional log-likelihood of the feature value. It should be in
242
+ // ! the range [0,1]. The smaller the value the less impact the feature
243
+ // ! has on class selection.
244
+ // ! \return The class probabilities and the minimum feature weight.
213
245
// ! \note \p x size should be equal to the number of features.
214
246
// ! A missing feature is indicated by passing an empty vector
215
247
// ! for that feature.
216
- TDoubleSizePrVec highestClassProbabilities (std::size_t n, const TDouble1VecVec& x) const ;
248
+ TDoubleSizePrVecDoublePr highestClassProbabilities (
249
+ std::size_t n,
250
+ const TDouble1VecVec& x,
251
+ const TFeatureWeightProvider& weightProvider = CUnitFeatureWeightProvider{}) const ;
217
252
218
253
// ! Get the probability of the class labeled \p label for \p x.
219
254
// !
220
255
// ! \param[in] label The label of the class of interest.
221
256
// ! \param[in] x The feature values.
257
+ // ! \param[in] weightProvider Computes a feature weight from the class
258
+ // ! conditional log-likelihood of the feature value. It should be in
259
+ // ! the range [0,1]. The smaller the value the less impact the feature
260
+ // ! has on class selection.
261
+ // ! \return The probability of the class labeled \p label and the minimum
262
+ // ! feature weight.
222
263
// ! \note \p x size should be equal to the number of features.
223
264
// ! A missing feature is indicated by passing an empty vector
224
265
// ! for that feature.
225
- double classProbability (std::size_t label, const TDouble1VecVec& x) const ;
266
+ TDoubleDoublePr classProbability (std::size_t label,
267
+ const TDouble1VecVec& x,
268
+ const TFeatureWeightProvider& weightProvider =
269
+ CUnitFeatureWeightProvider{}) const ;
226
270
227
271
// ! Get the probabilities of all the classes for \p x.
228
272
// !
229
273
// ! \param[in] x The feature values.
274
+ // ! \param[in] weightProvider Computes a feature weight from the class
275
+ // ! conditional log-likelihood of the feature value. It should be in
276
+ // ! the range [0,1]. The smaller the value the less impact the feature
277
+ // ! has on class selection.
278
+ // ! \return A pair: the first element is the vector of class
279
+ // ! probabilities and the second element is the minimum feature
280
+ // ! weight.
230
281
// ! \note \p x size should be equal to the number of features.
231
282
// ! A missing feature is indicated by passing an empty vector
232
283
// ! for that feature.
233
- TDoubleSizePrVec classProbabilities (const TDouble1VecVec& x) const ;
284
+ TDoubleSizePrVecDoublePr
285
+ classProbabilities (const TDouble1VecVec& x,
286
+ const TFeatureWeightProvider& weightProvider = CUnitFeatureWeightProvider{}) const ;
234
287
235
288
// ! Debug the memory used by this object.
236
289
void debugMemoryUsage (const core::CMemoryUsage::TMemoryUsagePtr& mem) const ;
@@ -298,13 +351,6 @@ class MATHS_COMMON_EXPORT CNaiveBayes {
298
351
bool validate (const TDouble1VecVec& x) const ;
299
352
300
353
private:
301
- // ! It is not always appropriate to use features with very low
302
- // ! probability in all classes to discriminate: the class choice
303
- // ! will be very sensitive to the underlying conditional density
304
- // ! model. This is a cutoff (for the minimum maximum class log
305
- // ! likelihood) in order to use a feature.
306
- TOptionalDouble m_MinMaxLogLikelihoodToUseFeature;
307
-
308
354
// ! Controls the rate at which data are aged out.
309
355
double m_DecayRate;
310
356
0 commit comments