presage  0.9.1
smoothedNgramPredictor.cpp

/******************************************************
 *  Presage, an extensible predictive text entry system
 *  ---------------------------------------------------
 *
 *  Copyright (C) 2008  Matteo Vescovi <matteo.vescovi@yahoo.co.uk>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License along
    with this program; if not, write to the Free Software Foundation, Inc.,
    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 **********(*)*/


#include "smoothedNgramPredictor.h"

#include <sstream>
#include <algorithm>

SmoothedNgramPredictor::SmoothedNgramPredictor(Configuration* config, ContextTracker* ct, const char* name)
    : Predictor(config,
                ct,
                name,
                "SmoothedNgramPredictor, a linear interpolating n-gram predictor",
                "SmoothedNgramPredictor, long description." ),
      db (0),
      cardinality (0),
      learn_mode_set (false),
      dispatcher (this)
{
    LOGGER          = PREDICTORS + name + ".LOGGER";
    DBFILENAME      = PREDICTORS + name + ".DBFILENAME";
    DELTAS          = PREDICTORS + name + ".DELTAS";
    LEARN           = PREDICTORS + name + ".LEARN";
    DATABASE_LOGGER = PREDICTORS + name + ".DatabaseConnector.LOGGER";
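    // Note: PREDICTORS is the common configuration namespace prefix, so
    // with an illustrative prefix of "Presage.Predictors." and a predictor
    // instantiated with name "SmoothedNgramPredictor", the DBFILENAME key
    // above resolves to
    // "Presage.Predictors.SmoothedNgramPredictor.DBFILENAME".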

    // build notification dispatch map
    dispatcher.map (config->find (LOGGER), & SmoothedNgramPredictor::set_logger);
    dispatcher.map (config->find (DBFILENAME), & SmoothedNgramPredictor::set_dbfilename);
    dispatcher.map (config->find (DELTAS), & SmoothedNgramPredictor::set_deltas);
    dispatcher.map (config->find (LEARN), & SmoothedNgramPredictor::set_learn);
    dispatcher.map (config->find (DATABASE_LOGGER), & SmoothedNgramPredictor::set_database_logger_level);
}


SmoothedNgramPredictor::~SmoothedNgramPredictor()
{
    delete db;
}


void SmoothedNgramPredictor::set_dbfilename (const std::string& filename)
{
    dbfilename = filename;
    logger << INFO << "DBFILENAME: " << dbfilename << endl;

    init_database_connector_if_ready ();
}


void SmoothedNgramPredictor::set_database_logger_level (const std::string& value)
{
    dbloglevel = value;
}
77 
78 
79 void SmoothedNgramPredictor::set_deltas (const std::string& value)
80 {
81  std::stringstream ss_deltas(value);
82  cardinality = 0;
83  std::string delta;
84  while (ss_deltas >> delta) {
85  logger << DEBUG << "Pushing delta: " << delta << endl;
86  deltas.push_back (Utility::toDouble (delta));
87  cardinality++;
88  }
89  logger << INFO << "DELTAS: " << value << endl;
90  logger << INFO << "CARDINALITY: " << cardinality << endl;
91 
93 }
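// Example (illustrative values): set_deltas("0.01 0.1 0.89") configures a
// trigram model; cardinality becomes 3 and predict() interpolates unigram,
// bigram and trigram relative frequencies with weights 0.01, 0.1 and 0.89
// respectively.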


void SmoothedNgramPredictor::set_learn (const std::string& value)
{
    learn_mode = Utility::isYes (value);
    logger << INFO << "LEARN: " << value << endl;

    learn_mode_set = true;

    init_database_connector_if_ready ();
}


void SmoothedNgramPredictor::init_database_connector_if_ready ()
{
    // we can only init the sqlite database connector once we know the
    // following:
    // - what database file we need to open
    // - what cardinality we expect the database file to be
    // - whether we need to open the database in read only or
    //   read/write mode (learning requires read/write access)
    //
    if (! dbfilename.empty()
        && cardinality > 0
        && learn_mode_set ) {

        delete db;

        if (dbloglevel.empty ()) {
            // open database connector
            db = new SqliteDatabaseConnector(dbfilename,
                                             cardinality,
                                             learn_mode);
        } else {
            // open database connector with logger level
            db = new SqliteDatabaseConnector(dbfilename,
                                             cardinality,
                                             learn_mode,
                                             dbloglevel);
        }
    }
}
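// Note: each of the setters above calls this function, so the connector is
// (re)created as soon as DBFILENAME, DELTAS and LEARN have all been
// dispatched, in whatever order the configuration delivers them.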


// convenience function to convert ngram to string
//
static std::string ngram_to_string(const Ngram& ngram)
{
    const char separator[] = "|";
    std::string result = separator;

    for (Ngram::const_iterator it = ngram.begin();
         it != ngram.end();
         it++)
    {
        result += *it + separator;
    }

    return result;
}


/** \brief Builds the required n-gram and returns its count.
 */
unsigned int SmoothedNgramPredictor::count(const std::vector<std::string>& tokens, int offset, int ngram_size) const
{
    unsigned int result = 0;

    assert(offset <= 0); // TODO: handle this better
    assert(ngram_size >= 0);

    if (ngram_size > 0) {
        Ngram ngram(ngram_size);
        copy(tokens.end() - ngram_size + offset, tokens.end() + offset, ngram.begin());
        result = db->getNgramCount(ngram);
        logger << DEBUG << "count ngram: " << ngram_to_string (ngram) << " : " << result << endl;
    } else {
        result = db->getUnigramCountsSum();
        logger << DEBUG << "unigram counts sum: " << result << endl;
    }

    return result;
}
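// Example: with tokens = {"the", "quick", "brown"}:
//   count(tokens,  0, 2) looks up the bigram <quick, brown>
//   count(tokens, -1, 2) looks up the bigram <the, quick>
//   count(tokens,  0, 0) returns the sum of all unigram counts
// i.e. a negative offset shifts the n-gram window back from the end of the
// tokens vector.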

Prediction SmoothedNgramPredictor::predict(const size_t max_partial_prediction_size, const char** filter) const
{
    logger << DEBUG << "predict()" << endl;

    // Result prediction
    Prediction prediction;

    // Cache all the needed tokens.
    // tokens[k] corresponds to w_{i-k} in the generalized smoothed
    // n-gram probability formula
    //
    std::vector<std::string> tokens(cardinality);
    for (int i = 0; i < cardinality; i++) {
        tokens[cardinality - 1 - i] = contextTracker->getToken(i);
        logger << DEBUG << "Cached tokens[" << cardinality - 1 - i << "] = " << tokens[cardinality - 1 - i] << endl;
    }

    // Generate list of prefix completion candidates.
    //
    // The prefix completion candidates used to be obtained from the
    // _1_gram table, because in a well-constructed ngram database the
    // _1_gram table contains all known tokens. However, this
    // introduced a skew, since the unigram counts take precedence
    // over the higher-order counts.
    //
    // The current solution retrieves candidates from the highest
    // n-gram table, falling back on lower order n-gram tables if the
    // initial completion set is smaller than required.
    //
    std::vector<std::string> prefixCompletionCandidates;
    for (size_t k = cardinality; (k > 0 && prefixCompletionCandidates.size() < max_partial_prediction_size); k--) {
        logger << DEBUG << "Building partial prefix completion table of cardinality: " << k << endl;
        // create n-gram used to retrieve initial prefix completion table
        Ngram prefix_ngram(k);
        copy(tokens.end() - k, tokens.end(), prefix_ngram.begin());

        if (logger.shouldLog()) {
            logger << DEBUG << "prefix_ngram: ";
            for (size_t r = 0; r < prefix_ngram.size(); r++) {
                logger << DEBUG << prefix_ngram[r] << ' ';
            }
            logger << DEBUG << endl;
        }

        // obtain initial prefix completion candidates
        db->beginTransaction();

        NgramTable partial;

        if (filter == 0) {
            partial = db->getNgramLikeTable(prefix_ngram, max_partial_prediction_size - prefixCompletionCandidates.size());
        } else {
            partial = db->getNgramLikeTableFiltered(prefix_ngram, filter, max_partial_prediction_size - prefixCompletionCandidates.size());
        }

        db->endTransaction();

        if (logger.shouldLog()) {
            logger << DEBUG << "partial prefixCompletionCandidates" << endl
                   << DEBUG << "----------------------------------" << endl;
            for (size_t j = 0; j < partial.size(); j++) {
                for (size_t k = 0; k < partial[j].size(); k++) {
                    logger << DEBUG << partial[j][k] << " ";
                }
                logger << endl;
            }
        }

        logger << DEBUG << "Partial prefix completion table contains " << partial.size() << " potential completions." << endl;

        // append newly discovered potential completions to prefix
        // completion candidates array to fill it up to
        // max_partial_prediction_size
        //
        std::vector<Ngram>::const_iterator it = partial.begin();
        while (it != partial.end() && prefixCompletionCandidates.size() < max_partial_prediction_size) {
            // only add new candidates; iterator it points to an Ngram,
            // it->end() - 2 points to the token candidate
            //
            std::string candidate = *(it->end() - 2);
            if (find(prefixCompletionCandidates.begin(),
                     prefixCompletionCandidates.end(),
                     candidate) == prefixCompletionCandidates.end()) {
                prefixCompletionCandidates.push_back(candidate);
            }
            it++;
        }
    }
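    // Each row of the NgramTable returned by the database is expected to be
    // laid out as <token_1, ..., token_k, count>, hence it->end() - 2 above
    // addresses the completing token and it->end() - 1 the count.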

    if (logger.shouldLog()) {
        logger << DEBUG << "prefixCompletionCandidates" << endl
               << DEBUG << "--------------------------" << endl;
        for (size_t j = 0; j < prefixCompletionCandidates.size(); j++) {
            logger << DEBUG << prefixCompletionCandidates[j] << endl;
        }
    }

    // compute smoothed probabilities for all candidates
    //
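    // The smoothed probability of candidate w_i given the context
    // w_{i-n+1} ... w_{i-1} is the linear interpolation
    //
    //   P(w_i | context) = sum_{k=0}^{n-1} deltas[k] * f_k
    //
    // where each f_k is the relative frequency
    //
    //   f_k = C(w_{i-k} ... w_i) / C(w_{i-k} ... w_{i-1})
    //
    // computed from the n-gram counts in the database (f_0 uses the
    // unigram counts sum as its denominator), as implemented by the
    // inner loop below.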
    db->beginTransaction();
    // getUnigramCountsSum is an expensive SQL query
    // caching it here saves much time later inside the loop
    int unigrams_counts_sum = db->getUnigramCountsSum();
    for (size_t j = 0; (j < prefixCompletionCandidates.size() && j < max_partial_prediction_size); j++) {
        // store w_i candidate at end of tokens
        tokens[cardinality - 1] = prefixCompletionCandidates[j];

        logger << DEBUG << "------------------" << endl;
        logger << DEBUG << "w_i: " << tokens[cardinality - 1] << endl;

        double probability = 0;
        for (int k = 0; k < cardinality; k++) {
            double numerator = count(tokens, 0, k+1);
            // reuse cached unigrams_counts_sum to speed things up
            double denominator = (k == 0 ? unigrams_counts_sum : count(tokens, -1, k));
            double frequency = ((denominator > 0) ? (numerator / denominator) : 0);
            probability += deltas[k] * frequency;

            logger << DEBUG << "numerator: " << numerator << endl;
            logger << DEBUG << "denominator: " << denominator << endl;
            logger << DEBUG << "frequency: " << frequency << endl;
            logger << DEBUG << "delta: " << deltas[k] << endl;

            // sanity checks
            assert(numerator <= denominator);
            assert(frequency <= 1);
        }

        logger << DEBUG << "____________" << endl;
        logger << DEBUG << "probability: " << probability << endl;

        if (probability > 0) {
            prediction.addSuggestion(Suggestion(tokens[cardinality - 1], probability));
        }
    }
    db->endTransaction();

    logger << DEBUG << "Prediction:" << endl;
    logger << DEBUG << "-----------" << endl;
    logger << DEBUG << prediction << endl;

    return prediction;
}

void SmoothedNgramPredictor::learn(const std::vector<std::string>& change)
{
    logger << INFO << "learn(\"" << ngram_to_string(change) << "\")" << endl;

    if (learn_mode) {
        // learning is turned on

        std::map<std::list<std::string>, int> ngramMap;

        // build up ngram map for all cardinalities
        // i.e. learn all ngrams and counts in memory
        for (size_t curr_cardinality = 1;
             curr_cardinality < cardinality + 1;
             curr_cardinality++)
        {
            int change_idx = 0;
            int change_size = change.size();

            std::list<std::string> ngram_list;

            // take care of first N-1 tokens
            for (int i = 0;
                 (i < curr_cardinality - 1 && change_idx < change_size);
                 i++)
            {
                ngram_list.push_back(change[change_idx]);
                change_idx++;
            }

            while (change_idx < change_size)
            {
                ngram_list.push_back(change[change_idx++]);
                ngramMap[ngram_list] = ngramMap[ngram_list] + 1;
                ngram_list.pop_front();
            }
        }
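        // Example: with change = {"foo", "bar"} and cardinality 2, the
        // loop above produces
        //   ngramMap[<foo>]      = 1
        //   ngramMap[<bar>]      = 1
        //   ngramMap[<foo, bar>] = 1
        // i.e. every n-gram of every order up to cardinality that occurs
        // in change, together with its occurrence count.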

        // use (past stream - change) to learn tokens at the boundary
        // of change, i.e.
        //
        // if change is "bar foobar", then "bar" will only occur in a
        // 1-gram, since there are no tokens before it. By dipping into
        // the past stream, we gain additional context to learn a 2-gram
        // from the extra tokens (assuming the past stream ends with
        // token "foo"):
        //
        //   <"foo", "bar"> will be learnt
        //
        // We do this till we build up to n equal to cardinality.
        //
        // First check that change is not empty (nothing to learn) and
        // that change and past stream match by sampling the first and
        // last token in change and comparing them with the corresponding
        // tokens from the past stream
        //
        if (change.size() > 0 &&
            change.back() == contextTracker->getToken(1) &&
            change.front() == contextTracker->getToken(change.size()))
        {
            // create ngram list with first (oldest) token from change
            std::list<std::string> ngram_list(change.begin(), change.begin() + 1);

            // prepend tokens to the ngram list by grabbing extra tokens
            // from the past stream (if there are any) till we have built
            // up to n == cardinality ngrams, and commit them to
            // ngramMap
            //
            for (int tk_idx = 1;
                 ngram_list.size() < cardinality;
                 tk_idx++)
            {
                // getExtraTokenToLearn returns tokens from the
                // past stream that come before and are not in the
                // change vector
                //
                std::string extra_token = contextTracker->getExtraTokenToLearn(tk_idx, change);
                logger << DEBUG << "Adding extra token: " << extra_token << endl;

                if (extra_token.empty())
                {
                    break;
                }
                ngram_list.push_front(extra_token);

                ngramMap[ngram_list] = ngramMap[ngram_list] + 1;
            }
        }

        // then write out to language model database
        try
        {
            db->beginTransaction();

            std::map<std::list<std::string>, int>::const_iterator it;
            for (it = ngramMap.begin(); it != ngramMap.end(); it++)
            {
                // convert ngram from list to vector based Ngram
                Ngram ngram((it->first).begin(), (it->first).end());

                // update the counts
                int count = db->getNgramCount(ngram);
                if (count > 0)
                {
                    // ngram already in database, update count
                    db->updateNgram(ngram, count + it->second);
                    check_learn_consistency(ngram);
                }
                else
                {
                    // ngram not in database, insert it
                    db->insertNgram(ngram, it->second);
                }
            }

            db->endTransaction();
            logger << INFO << "Committed learning update to database" << endl;
        }
        catch (PresageException& ex)
        {
            db->rollbackTransaction();
            logger << ERROR << "Rolling back learning update : " << ex.what() << endl;
            throw;
        }
    }

    logger << DEBUG << "end learn()" << endl;
}

void SmoothedNgramPredictor::check_learn_consistency(const Ngram& ngram) const
{
    // no need to begin a new transaction, as we'll be called from
    // within an existing transaction from learn()

    // BEWARE: if the previous sentence is not true, then performance
    // WILL suffer!

    size_t size = ngram.size();
    for (size_t i = 0; i < size; i++) {
        if (count(ngram, -i, size - i) > count(ngram, -(i + 1), size - (i + 1))) {
            logger << INFO << "consistency adjustment needed!" << endl;

            int offset = -(i + 1);
            int sub_ngram_size = size - (i + 1);

            logger << DEBUG << "i: " << i << " | offset: " << offset << " | sub_ngram_size: " << sub_ngram_size << endl;

            Ngram sub_ngram(sub_ngram_size); // need to init to right size for sub_ngram
            copy(ngram.end() - sub_ngram_size + offset, ngram.end() + offset, sub_ngram.begin());

            if (logger.shouldLog()) {
                logger << "ngram to be count adjusted is: ";
                for (size_t i = 0; i < sub_ngram.size(); i++) {
                    logger << sub_ngram[i] << ' ';
                }
                logger << endl;
            }

            db->incrementNgramCount(sub_ngram);
            logger << DEBUG << "consistency adjusted" << endl;
        }
    }
}
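// Invariant enforced above: the count of an n-gram can never exceed the
// count of its (n-1)-token prefix. For example, if <foo, bar> has count 5,
// then <foo> must have count >= 5, since every occurrence of <foo, bar> is
// also an occurrence of <foo>.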

void SmoothedNgramPredictor::update (const Observable* var)
{
    logger << DEBUG << "About to invoke dispatcher: " << var->get_name () << " - " << var->get_value() << endl;
    dispatcher.dispatch (var);
}