00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
#include "phrasehunter/statistics.h"
#include "phrasehunter/contextreader.h"
#include "phrasehunter/token.h"
#include "phrasehunter/tokencontext.h"

#include <algorithm>
#include <cmath>
#include <cstring>
#include <iostream>
#include <iterator>
#include <map>
#include <set>
#include <tr1/unordered_set>
#include <vector>
00035
00036 using namespace SQLitePP;
00037
00038 namespace PhraseHunter {
00039
00040 struct RankingCriterion
00041 {
00042 bool operator() (TokenPtr t1, TokenPtr t2) const
00043 {
00044 return t1->corpusFrequency() > t2->corpusFrequency();
00045 }
00046 };
00047
00048
00049
00050
00051
// Orders two containers by element count, ascending.
template<typename T>
struct Smaller
{
    bool operator() (const T& lhs, const T& rhs) const
    {
        return lhs.size() < rhs.size();
    }
};
00060
00061 StatisticsEngine::StatisticsEngine(SearchEngine* searcher, ContextReader* reader, SqliteDB& db)
00062 : m_searcher(searcher),
00063 m_reader(reader),
00064 m_db(db)
00065 {
00066 m_sizeOfSampleSpace = 0;
00067 m_numberOfTypes = 0;
00068
00069 Statement::Pointer stmt = m_db.statement("select word, frequency from tokens");
00070 unsigned long totalLength = 0;
00071 for(ResultIterator ri = stmt->query(); ri.hasMoreRows(); ri.next()) {
00072 totalLength += ri.get<schma::UnicodePtr>(0)->length() * ri.get<int>(1);
00073
00074 m_sizeOfSampleSpace += ri.get<int>(1);
00075 ++m_numberOfTypes;
00076 }
00077 m_averageWordLength = static_cast<double>(totalLength) / static_cast<double>(m_sizeOfSampleSpace);
00078
00079 m_contextWindow = static_cast<unsigned>(m_averageWordLength*3);
00080
00081 Statement::Pointer numDoc = m_db.statement("select sum(length(id))/4 from docs");
00082 ResultIterator ri = numDoc->query();
00083 m_numberOfDocuments = ri.get<int>(0);
00084 }
00085
00086 unsigned int StatisticsEngine::rank(TokenPtr t) const
00087 {
00088 if (m_searcher == NULL)
00089 return 0;
00090
00091 Statement::Pointer smt = m_db.statement("select rank from tokens where id = ?");
00092 ResultIterator ri = (*smt)
00093 .bindArg(t->id())
00094 .query();
00095 return ri.get<int>(0);
00096 }
00097
00098 unsigned int StatisticsEngine::overallFrequency(TokenPtr t) const
00099 {
00100 if (m_searcher == NULL)
00101 return 0;
00102
00103 Statement::Pointer smt = m_db.statement("select frequency from tokens where id = ?");
00104 ResultIterator ri = (*smt)
00105 .bindArg(t->id())
00106 .query();
00107 return ri.get<int>(0);
00108 }
00109
00110 unsigned int StatisticsEngine::frequency_x_intersection_y(TokenPtr x, TokenPtr y) const
00111 {
00112 unsigned comm_occs = 0;
00113 const OccurrenceMap& shorter = std::min(x->allOccurrences(), y->allOccurrences(),
00114 Smaller<OccurrenceMap>());
00115 const OccurrenceMap& longer = std::max(x->allOccurrences(), y->allOccurrences(),
00116 Smaller<OccurrenceMap>());
00117
00118 for(OccurrenceMap::const_iterator shorter_it = shorter.begin();
00119 shorter_it != shorter.end(); ++shorter_it) {
00120
00121 OccurrenceMap::const_iterator longer_it = longer.find(shorter_it->first);
00122
00123 if (longer_it != longer.end()) {
00124 comm_occs += freq_intersection_in_doc(shorter_it->second, longer_it->second);
00125 }
00126 }
00127 return comm_occs;
00128 }
00129
00130 unsigned StatisticsEngine::freq_intersection_in_doc(const PositionList& p1, const PositionList& p2) const
00131 {
00132 const PositionList& shorter = min(p1, p2, Smaller<PositionList>());
00133 const PositionList& longer = max(p1, p2, Smaller<PositionList>());
00134
00135 unsigned comm_occs = 0;
00136 PositionList::const_iterator i1, i2;
00137 for (PositionList::const_iterator it = shorter.begin();
00138 it != shorter.end(); ++it) {
00139
00140 if (static_cast<int>(*it) - static_cast<int>(m_contextWindow) > 0)
00141 i1 = std::lower_bound(longer.begin(), longer.end(), *it-m_contextWindow);
00142 else
00143 i1 = longer.begin();
00144
00145 i2 = std::lower_bound(longer.begin(), longer.end(), *it+m_contextWindow);
00146
00147 if (i2 != i1)
00148 comm_occs += i2 - i1;
00149 }
00150 return comm_occs;
00151 }
00152
00153 double StatisticsEngine::relative_frequency(TokenPtr x) const
00154 {
00155 if (getSizeOfSampleSpace())
00156 return overallFrequency(x)/static_cast<double>(getSizeOfSampleSpace());
00157 else
00158 return 0.0;
00159 }
00160
00161 double StatisticsEngine::mle_x_intersection_y(TokenPtr x, TokenPtr y) const
00162 {
00163 if (getSizeOfSampleSpace())
00164 return static_cast<double>(frequency_x_intersection_y(x,y))/static_cast<double>(getSizeOfSampleSpace());
00165 else
00166 return 0.0;
00167 }
00168
00169 double StatisticsEngine::mle_x_given_y(TokenPtr x, TokenPtr y) const
00170 {
00171 double p_of_x = relative_frequency(y);
00172 if (p_of_x)
00173 return mle_x_intersection_y(x,y)/p_of_x;
00174 else
00175 return 0.0;
00176 }
00177
00178 std::vector<TokenPtr> StatisticsEngine::getCandidates(TokenPtr x) const
00179 {
00180 OccurrenceMap occs = x->allOccurrences();
00181
00182 std::set<TokenID> tokenIds;
00183
00184 Statement::Pointer getWordsInRange = m_db.statement("SELECT wordid FROM occurrences WHERE docID = ? and inrange(positions, ?, ?)");
00185
00186 for(OccurrenceMap::const_iterator itOccs = occs.begin(); itOccs != occs.end(); ++itOccs) {
00187
00188 for(PositionList::const_iterator itPos=itOccs->second.begin();
00189 itPos!=itOccs->second.end(); ++itPos) {
00190
00191 std::pair<IdxPos, IdxPos> bounds = contextBounds(*itPos, getContextSize());
00192
00193 ResultIterator ri = (*getWordsInRange)
00194 .bindArgs(itOccs->first, bounds.first, bounds.second)
00195 .query();
00196
00197 while(ri.hasMoreRows()) {
00198 tokenIds.insert(ri.get<int>(0));
00199 }
00200 }
00201 }
00202
00203 std::vector<TokenPtr> candidateStrings;
00204
00205 Statement::Pointer getTokenString = m_db.statement("SELECT word FROM tokens WHERE id = ?");
00206
00207 for(std::set<TokenID>::const_iterator tokIt = tokenIds.begin();
00208 tokIt != tokenIds.end();
00209 ++tokIt) {
00210
00211 ResultIterator ri = (*getTokenString)
00212 .bindArg(*tokIt)
00213 .query();
00214
00215 if(ri.hasMoreRows() && strlen(ri.get<const char*>(0)) > 3)
00216 candidateStrings.push_back(CorpusToken::loadFromCorpus(ri.get<schma::UnicodePtr>(0), *tokIt, m_db));
00217 }
00218 return candidateStrings;
00219 }
00220
00221 std::pair<IdxPos, IdxPos> StatisticsEngine::contextBounds(IdxPos position, int numWords) const
00222 {
00223 int lower = static_cast<int>(position) - (numWords*static_cast<int>(m_averageWordLength)+numWords);
00224 if (lower < 0)
00225 lower = 0;
00226 IdxPos upper = position + (numWords*static_cast<int>(m_averageWordLength)+numWords);
00227 return std::make_pair(static_cast<IdxPos>(lower), upper);
00228 }
00229
00230 class DocSet
00231 {
00232 std::set<DocID> docSet;
00233 public:
00234 void operator()(TokenPtr t)
00235 {
00236 std::vector<DocID> docs = t->documentIDs();
00237 docSet.insert(docs.begin(), docs.end());
00238 }
00239 operator unsigned int() const
00240 {
00241 return docSet.size();
00242 }
00243 };
00244
00245 unsigned int StatisticsEngine::sizeOfDocSet(const std::vector<TokenPtr>& tokens)
00246 {
00247 return std::for_each(tokens.begin(), tokens.end(), DocSet());
00248 }
00249
00250 std::map<schma::UnicodePtr, int> StatisticsEngine::getContextVector(TokenPtr t, int contextlen) const
00251 {
00252 std::map<schma::UnicodePtr, int> results;
00253 if (m_searcher == NULL)
00254 return results;
00255
00256 typedef std::vector<TokenContextPtr> contextVector;
00257 contextVector stringContext = m_reader->context(t, static_cast<int>(contextlen*m_averageWordLength));
00258
00259 for (contextVector::const_iterator it = stringContext.begin();
00260 it != stringContext.end(); ++it) {
00261
00262 schma::UnicodePtr left = (*it)->leftContext();
00263 IdxPos spaceIdx = left->indexOf(' ');
00264 schma::UnicodePtr cutContext(new UnicodeString);
00265 left->extract(spaceIdx, left->length()-spaceIdx, *cutContext);
00266
00267 if (cutContext->length() <= 3)
00268 continue;
00269
00270 int freq = ++results[cutContext];
00271
00272 results.insert(std::make_pair(cutContext, freq));
00273 }
00274 return results;
00275 }
00276
00277 double StatisticsEngine::mutual_information(const TokenPtr& token, const std::map<schma::UnicodePtr,int>& contextVec, schma::UnicodePtr context) const
00278 {
00279 TokenPtr contextToken;
00280 if (token->tokenString()->indexOf(' ') != -1)
00281 contextToken = m_searcher->searchPhrase(context);
00282 else
00283 contextToken = m_searcher->searchToken(context);
00284
00285 if (contextToken == NULL)
00286 return 0;
00287
00288 double frequency_w_in_c;
00289
00290 std::map<schma::UnicodePtr,int>::const_iterator found = contextVec.find(context);
00291 if (found != contextVec.end())
00292 frequency_w_in_c = found->second;
00293 else
00294 frequency_w_in_c = 0;
00295
00296 double numerator = frequency_w_in_c / static_cast<double>(getSizeOfSampleSpace());
00297 double denominator = relative_frequency(token)*relative_frequency(contextToken);
00298
00299 double discounting_factor =
00300 frequency_w_in_c/(frequency_w_in_c+1)
00301 *
00302 (std::min(overallFrequency(token), overallFrequency(contextToken))
00303 /
00304 (std::min(overallFrequency(token), overallFrequency(contextToken))+1));
00305
00306 return discounting_factor*(numerator/denominator);
00307 }
00308
00309 double StatisticsEngine::similarity(TokenPtr t1, TokenPtr t2, int contextlen) const
00310 {
00311 typedef std::map<schma::UnicodePtr,int> UnicodeFreqMap;
00312
00313 UnicodeFreqMap c1 = getContextVector(t1, contextlen);
00314 UnicodeFreqMap c2 = getContextVector(t2, contextlen);
00315
00316 UnicodeFreqMap* shorter = c1.size() <= c2.size() ? &c1 : &c2;
00317 double numerator_sum = 0;
00318 for (UnicodeFreqMap::const_iterator it = shorter->begin();
00319 it != shorter->end(); ++it) {
00320 double mi = mutual_information(t1, *shorter, it->first);
00321 numerator_sum += mi * mi;
00322 }
00323
00324 double denominator_first_sum = 0;
00325
00326 for (UnicodeFreqMap::const_iterator it = c1.begin(); it != c1.end(); ++it)
00327 denominator_first_sum += mutual_information(t1,c1,it->first);
00328
00329 double denominator_second_sum = 0;
00330 for (UnicodeFreqMap::const_iterator it = c2.begin(); it != c2.end(); ++it)
00331 denominator_first_sum += mutual_information(t2,c2,it->first);
00332
00333 return numerator_sum/std::sqrt((denominator_first_sum * denominator_first_sum)
00334 *
00335 (denominator_second_sum * denominator_second_sum));
00336 }
00337
00338 std::multimap<double,TokenPtr> StatisticsEngine::similarTokens(TokenPtr t, int contextlen) const
00339 {
00340 std::multimap<double,TokenPtr> sorted;
00341
00342 if (t == NULL)
00343 return sorted;
00344
00345 TokenVector candidates = getCandidates(t);
00346
00347 for (TokenVector::const_iterator it = candidates.begin();
00348 it != candidates.end(); ++it) {
00349 sorted.insert(std::make_pair(similarity(t, *it, contextlen), *it));
00350 }
00351 return sorted;
00352 }
00353
00354 }