Gibbs Sampling实现LDA

论坛 期权论坛 脚本     
匿名技术用户   2021-1-6 08:48   125   0


关于LDA的介绍见前面几篇文章，这里给出用Gibbs抽样求解LDA的Java实现。


可以看到，收敛之后各主题的结果基本保持不变。

package org.jazywoo.lda;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * One document of the corpus: a name plus the term ID of each of its tokens, in order.
 */
public class Document {
    /** Term ID (index into the corpus vocabulary) of each token, in order; {@code null} until set. */
    private List<Integer> words;
    /** Name the document was registered under. */
    private String docName;

    /**
     * Creates a named document that has no words yet.
     *
     * @param docName name of the document
     */
    public Document(String docName) {
        this.docName = docName;
    }

    /**
     * @return the term IDs of this document, or {@code null} if they were never set
     */
    public List<Integer> getWords() {
        return this.words;
    }

    /**
     * Replaces the term-ID list. The list is stored by reference, not copied.
     *
     * @param words term IDs in token order
     */
    public void setWords(List<Integer> words) {
        this.words = words;
    }

    /**
     * @return the document name
     */
    public String getDocName() {
        return this.docName;
    }

    /**
     * @param docName new document name
     */
    public void setDocName(String docName) {
        this.docName = docName;
    }
}

package org.jazywoo.lda;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.StringTokenizer;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.jazywoo.tokenization.Tokenization;

import ICTCLAS.I3S.AC.ICTCLAS50;

/**
 * A corpus of tokenized documents together with its vocabulary.
 * <p>
 * Every distinct term gets a dense integer ID in order of first appearance:
 * {@code terms.get(id)} is the term and {@code termIndexMap.get(term)} is its ID.
 * Tokenization is delegated to {@link Tokenization} backed by {@code ICTCLAS50}.
 */
public class Corpus {
    /** Documents in the order they were loaded. */
    private List<Document> docs;
    /** Term -> term ID (its position in {@link #terms}). */
    private Map<String, Integer> termIndexMap;
    /** Vocabulary: position {@code i} holds the term whose ID is {@code i}. */
    private List<String> terms;
    /** Term -> total number of occurrences over all loaded documents. */
    private Map<String, Integer> termCountMap;

    public Corpus() {
        docs = new ArrayList<Document>();
        termIndexMap = new HashMap<String, Integer>();
        terms = new ArrayList<String>();
        termCountMap = new HashMap<String, Integer>();
    }

    /**
     * Loads every regular file directly inside directory {@code path} as one document
     * named after the file. Lines are joined with a space before tokenization.
     * Files are processed sorted by pathname so document indices are reproducible;
     * sub-directories are skipped. A non-existent {@code path} loads nothing.
     *
     * @param path directory containing one text file per document
     * @throws IOException if {@code path} exists but cannot be listed as a directory,
     *                     or if a file cannot be read
     */
    public void loadData(String path) throws IOException {
        File folder = new File(path);
        if (!folder.exists()) {
            return;
        }
        File[] files = folder.listFiles();
        if (files == null) {
            // listFiles() returns null for non-directories and I/O errors.
            throw new IOException("Not a readable directory: " + path);
        }
        // listFiles() order is unspecified; sort so document m is stable across runs.
        java.util.Arrays.sort(files);
        for (File f : files) {
            if (!f.isFile()) {
                continue; // FileReader cannot open directories
            }
            addDocument(f.getName(), readContent(f));
        }
    }

    /**
     * Reads a file with the JVM default charset ({@link FileReader}) and joins its
     * lines, appending a space after each one.
     */
    private static String readContent(File f) throws IOException {
        StringBuilder buf = new StringBuilder();
        BufferedReader br = new BufferedReader(new FileReader(f));
        try {
            String line;
            while ((line = br.readLine()) != null) {
                buf.append(line).append(' ');
            }
        } finally {
            br.close(); // previously leaked one file handle per document
        }
        return buf.toString();
    }

    /**
     * Tokenizes {@code content}, assigns IDs to unseen terms, updates the term counts
     * and appends the resulting {@link Document}.
     *
     * @param docName name stored on the new document
     * @param content raw document text
     */
    private void addDocument(String docName, String content) {
        Document document = new Document(docName);
        String[] words = getWordsFromSentence(content);
        List<Integer> wordsList = new ArrayList<Integer>(words.length);
        for (String term : words) {
            Integer termID = termIndexMap.get(term);
            if (termID == null) {
                // First occurrence: give the term the next free ID and count it once.
                termID = termIndexMap.size();
                termIndexMap.put(term, termID);
                terms.add(term);
                termCountMap.put(term, 1);
            } else {
                termCountMap.put(term, termCountMap.get(term) + 1);
            }
            wordsList.add(termID);
        }
        document.setWords(wordsList);
        docs.add(document);
    }

    /**
     * Segments text into words using a freshly initialised ICTCLAS tokenizer.
     * NOTE(review): the original comment says stop words and noise words are filtered;
     * the method name suggests symbols are removed. That behavior lives in
     * {@code Tokenization} and is not verifiable here.
     *
     * @param content text to segment
     * @return the words, or an empty array if the tokenizer could not be
     *         initialised or segmentation failed (never {@code null})
     */
    private String[] getWordsFromSentence(String content) {
        ICTCLAS50 ictclas = new ICTCLAS50();
        Tokenization tokenization = new Tokenization(ictclas);
        if (!tokenization.init()) {
            System.err.println("Tokenizer initialization failed; document treated as empty");
            return new String[0];
        }
        String[] words = null;
        try {
            words = tokenization.getPartedWordsWithoutSimbol(content);
        } catch (UnsupportedEncodingException e) {
            // Best effort: an undecodable document contributes no words.
            e.printStackTrace();
        } finally {
            tokenization.finish();
        }
        return words != null ? words : new String[0];
    }

    public List<Document> getDocs() {
        return docs;
    }

    public void setDocs(List<Document> docs) {
        this.docs = docs;
    }

    public Map<String, Integer> getTermIndexMap() {
        return termIndexMap;
    }

    public void setTermIndexMap(Map<String, Integer> termIndexMap) {
        this.termIndexMap = termIndexMap;
    }

    public List<String> getTerms() {
        return terms;
    }

    public void setTerms(List<String> terms) {
        this.terms = terms;
    }

    /**
     * @return term -> number of occurrences over all loaded documents
     */
    public Map<String, Integer> getTermCountMap() {
        return termCountMap;
    }
}

package org.jazywoo.lda;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;


/**
 * Latent Dirichlet Allocation (LDA) trained with collapsed Gibbs sampling.
 * <p>
 * Typical use: construct with a {@link ModelParameter}, call {@link #initModal(Corpus)},
 * then {@link #inferenceModel()}. {@link #saveIteratedModel(int)} writes the top words
 * of every topic to a text file under {@link #getResultPath()}.
 *
 * @author jazywoo
 */
public class LdaModel {
    /** Output directory used unless changed via {@link #setResultPath(String)}. */
    private static final String DEFAULT_RESULT_PATH = "D:\\result\\";
    /** Number of highest-probability terms written for each topic. */
    private static final int TOP_WORDS_PER_TOPIC = 20;

    private Corpus docSet; // corpus being modelled
    private int[][] doc; // doc[m][n]: term ID of the n-th word of document m
    private int V, K, M; // vocabulary size, topic number, document number
    private int[][] z; // z[m][n]: topic currently assigned to the n-th word of document m
    private float alpha; // doc-topic Dirichlet prior parameter
    private float beta; // topic-word Dirichlet prior parameter
    private int[][] nmk; // nmk[m][k]: words of document m assigned to topic k (M*K)
    private int[][] nkt; // nkt[k][t]: occurrences of term t assigned to topic k (K*V)
    private int[] nmkSum; // nmkSum[m]: number of words in document m (row sums of nmk)
    private int[] nktSum; // nktSum[k]: number of words assigned to topic k (row sums of nkt)

    // Point estimates derived from the counts by updateEstimatedParameters():
    // theta[m] is the topic distribution of document m (K entries), phi[k] the
    // term distribution of topic k (V entries).
    private double[][] theta; // M*K
    private double[][] phi; // K*V

    private int iterations; // number of Gibbs sweeps
    private int saveStep; // number of sweeps between two saves
    private int beginSaveIters; // first sweep at which the model is saved

    private String resultPath = DEFAULT_RESULT_PATH; // output directory for saveIteratedModel

    /**
     * Creates a model from the given hyper-parameters (values are copied).
     *
     * @param parameter hyper-parameters
     * @throws IllegalArgumentException if {@code topicNum <= 0} or a prior is not positive
     */
    public LdaModel(LdaModel.ModelParameter parameter) {
        if (parameter.topicNum <= 0) {
            throw new IllegalArgumentException("topicNum must be positive: " + parameter.topicNum);
        }
        // Dirichlet priors must be > 0; with 0 an empty topic/document yields 0/0 = NaN
        // probabilities and the sampler can no longer pick a valid topic.
        if (!(parameter.alpha > 0) || !(parameter.beta > 0)) {
            throw new IllegalArgumentException("alpha and beta must be positive: alpha="
                    + parameter.alpha + ", beta=" + parameter.beta);
        }
        alpha = parameter.alpha;
        beta = parameter.beta;
        iterations = parameter.iteration;
        K = parameter.topicNum;
        saveStep = parameter.saveStep;
        beginSaveIters = parameter.beginSaveIters;
    }

    /**
     * Binds the model to a corpus, copies its term IDs and draws a uniformly random
     * initial topic for every word, filling all count tables. Must be called before
     * {@link #inferenceModel()} and {@link #saveIteratedModel(int)}.
     *
     * @param docSet1 corpus to model
     */
    public void initModal(Corpus docSet1) {
        this.docSet = docSet1;
        M = docSet.getDocs().size();
        V = docSet.getTerms().size();
        nmk = new int[M][K];
        nkt = new int[K][V];
        nmkSum = new int[M];
        nktSum = new int[K];
        phi = new double[K][V];
        theta = new double[M][K];
        doc = new int[M][];
        z = new int[M][];
        for (int m = 0; m < M; m++) {
            List<Integer> words = docSet.getDocs().get(m).getWords();
            int N = words.size();
            doc[m] = new int[N];
            z[m] = new int[N];
            for (int n = 0; n < N; n++) {
                doc[m][n] = words.get(n);
                int initTopic = (int) (Math.random() * K); // uniform in [0, K - 1]
                z[m][n] = initTopic;
                nmk[m][initTopic]++;
                nkt[initTopic][doc[m][n]]++;
                nktSum[initTopic]++;
            }
            nmkSum[m] = N;
        }
    }

    /**
     * Runs {@code iterations} Gibbs sweeps over every word of every document.
     * Starting at sweep {@code beginSaveIters} and then every {@code saveStep} sweeps,
     * the estimates are refreshed and saved before that sweep runs. After the last
     * sweep, phi/theta are refreshed once more so they describe the final sample.
     *
     * @throws IllegalStateException if {@link #initModal(Corpus)} was not called,
     *         {@code saveStep <= 0}, or {@code iterations < saveStep + beginSaveIters}
     * @throws IOException if saving an intermediate model fails
     */
    public void inferenceModel() throws IOException {
        if (z == null) {
            throw new IllegalStateException("initModal must be called before inferenceModel");
        }
        if (saveStep <= 0) {
            throw new IllegalStateException("Error: saveStep must be positive, but was " + saveStep);
        }
        if (iterations < saveStep + beginSaveIters) {
            throw new IllegalStateException("Error: the number of iterations should be at least "
                    + (saveStep + beginSaveIters));
        }
        for (int i = 0; i < iterations; i++) {
            System.out.println("Iteration " + i);
            if ((i >= beginSaveIters) && (((i - beginSaveIters) % saveStep) == 0)) {
                System.out.println("Saving model at iteration " + i + " ... ");
                updateEstimatedParameters();
                saveIteratedModel(i);
            }
            for (int m = 0; m < M; m++) {
                for (int n = 0; n < doc[m].length; n++) {
                    z[m][n] = sampleTopicZ(m, n);
                }
            }
        }
        // Without this, phi/theta would still reflect the last periodic snapshot.
        updateEstimatedParameters();
    }

    /**
     * Resamples the topic of word {@code n} of document {@code m} from its full
     * conditional p(z_i | z_-i, w).
     * <p>
     * The word's current assignment is removed from all counts. Every topic k then
     * gets the unnormalised weight
     * (n_kt + beta) / (n_k + V*beta) * (n_mk + alpha) / (n_m + K*alpha).
     * A topic is drawn proportionally (roulette wheel), and the counts are restored
     * with the new topic.
     *
     * @param m document index
     * @param n word position within the document
     * @return the newly sampled topic, in {@code [0, K)}
     */
    private int sampleTopicZ(int m, int n) {
        int term = doc[m][n];

        // Remove the current assignment of w_{m,n}.
        int oldTopic = z[m][n];
        nmk[m][oldTopic]--;
        nkt[oldTopic][term]--;
        nmkSum[m]--;
        nktSum[oldTopic]--;

        // Unnormalised conditional, in double precision (the float priors used to
        // force the whole expression into float arithmetic).
        double a = alpha;
        double b = beta;
        double vBeta = V * b;
        double kAlpha = K * a;
        double[] p = new double[K];
        for (int k = 0; k < K; k++) {
            p[k] = (nkt[k][term] + b) / (nktSum[k] + vBeta)
                    * (nmk[m][k] + a) / (nmkSum[m] + kAlpha);
        }

        // Roulette-wheel draw on the cumulative (unnormalised) weights.
        for (int k = 1; k < K; k++) {
            p[k] += p[k - 1];
        }
        double u = Math.random() * p[K - 1];
        int newTopic;
        // Stop at K - 1 so the result is always a valid topic index.
        for (newTopic = 0; newTopic < K - 1; newTopic++) {
            if (u < p[newTopic]) {
                break;
            }
        }

        // Add the new assignment of w_{m,n}.
        nmk[m][newTopic]++;
        nkt[newTopic][term]++;
        nmkSum[m]++;
        nktSum[newTopic]++;
        return newTopic;
    }

    /**
     * Recomputes the point estimates from the current counts:
     * phi[k][t] = (n_kt + beta) / (n_k + V*beta) and
     * theta[m][k] = (n_mk + alpha) / (n_m + K*alpha).
     */
    private void updateEstimatedParameters() {
        double a = alpha;
        double b = beta;
        double vBeta = V * b;
        double kAlpha = K * a;
        for (int k = 0; k < K; k++) {
            for (int t = 0; t < V; t++) {
                phi[k][t] = (nkt[k][t] + b) / (nktSum[k] + vBeta);
            }
        }
        for (int m = 0; m < M; m++) {
            for (int k = 0; k < K; k++) {
                theta[m][k] = (nmk[m][k] + a) / (nmkSum[m] + kAlpha);
            }
        }
    }

    /**
     * Writes the (up to) 20 most probable terms of every topic, ranked by the current
     * phi, to {@code <resultPath>/lda_<iters>.topic_words.txt}. The file uses the
     * JVM default charset and has one line per topic. The result directory is
     * created if missing. Uses phi as last computed; it does not refresh it.
     *
     * @param iters iteration number embedded in the output file name
     * @throws IllegalStateException if {@link #initModal(Corpus)} was not called
     * @throws IOException if the directory cannot be created or the file cannot be written
     */
    public void saveIteratedModel(int iters) throws IOException {
        if (phi == null) {
            throw new IllegalStateException("initModal must be called before saveIteratedModel");
        }
        File dir = new File(resultPath);
        if (!dir.mkdirs() && !dir.isDirectory()) {
            throw new IOException("Cannot create result directory: " + dir);
        }
        File outFile = new File(dir, "lda_" + iters + ".topic_words.txt");
        List<String> vocabulary = docSet.getTerms();
        // Vocabularies with fewer than 20 terms previously caused IndexOutOfBoundsException.
        int topNum = Math.min(TOP_WORDS_PER_TOPIC, V);

        BufferedWriter writer = new BufferedWriter(new FileWriter(outFile));
        try {
            for (int k = 0; k < K; k++) {
                List<Integer> termIds = new ArrayList<Integer>(V);
                for (int t = 0; t < V; t++) {
                    termIds.add(Integer.valueOf(t));
                }
                // Stable sort: ties keep ascending term-ID order.
                Collections.sort(termIds, new ArrayDoubleComparator(phi[k]));
                writer.write("topic " + k + ":\t");
                for (int r = 0; r < topNum; r++) {
                    writer.write(vocabulary.get(termIds.get(r)) + " ");
                }
                writer.write("\n");
            }
        } finally {
            writer.close();
        }
    }

    /**
     * @return directory that {@link #saveIteratedModel(int)} writes into
     */
    public String getResultPath() {
        return resultPath;
    }

    /**
     * Sets the directory that {@link #saveIteratedModel(int)} writes into.
     *
     * @param resultPath output directory, created on save if missing
     * @throws IllegalArgumentException if {@code resultPath} is null
     */
    public void setResultPath(String resultPath) {
        if (resultPath == null) {
            throw new IllegalArgumentException("resultPath must not be null");
        }
        this.resultPath = resultPath;
    }

    /**
     * Orders term IDs by decreasing {@code sortProb[id]}; used to rank a topic's terms by phi.
     *
     * @author jazywoo
     */
    public class ArrayDoubleComparator implements Comparator<Integer> {
        private double[] sortProb; // probability of each term within one topic

        public ArrayDoubleComparator(double[] sortProb) {
            this.sortProb = sortProb;
        }

        @Override
        public int compare(Integer o1, Integer o2) {
            return Double.compare(sortProb[o2], sortProb[o1]); // descending
        }
    }

    /** Hyper-parameters of the model; public mutable fields with defaults. */
    public static class ModelParameter {
        public float alpha = 0.5f; // doc-topic prior; a common heuristic is 50 / topicNum
        public float beta = 0.1f; // topic-word prior; 0.1 is a common choice
        public int topicNum = 10; // number of topics K
        public int iteration = 100; // number of Gibbs sweeps
        public int saveStep = 10; // sweeps between two saves
        public int beginSaveIters = 80; // first sweep at which the model is saved
    }
}

package org.jazywoo.lda;

import java.io.IOException;

/**
 * Command-line driver: trains an LDA model on a directory of text files and
 * saves the topic words after the final iteration.
 */
public class LDATest {

    /** Corpus directory used when no command-line argument is given. */
    private static final String DEFAULT_CORPUS_PATH = "D:\\zz";

    /**
     * Loads the corpus, initialises and runs the Gibbs sampler with the default
     * {@link LdaModel.ModelParameter}, then saves the model under the total
     * iteration count.
     *
     * @param args optional; {@code args[0]} overrides the corpus directory
     * @throws IOException if reading the corpus or writing results fails
     */
    public static void main(String[] args) throws IOException {
        String path = (args != null && args.length > 0) ? args[0] : DEFAULT_CORPUS_PATH;

        LdaModel.ModelParameter parameter = new LdaModel.ModelParameter();
        LdaModel ldaModel = new LdaModel(parameter);

        Corpus docSet = new Corpus();
        docSet.loadData(path);
        if (docSet.getDocs().isEmpty()) {
            // An empty corpus (e.g. missing directory) used to crash later while saving.
            System.err.println("No documents found under " + path);
            return;
        }
        ldaModel.initModal(docSet);
        ldaModel.inferenceModel();
        ldaModel.saveIteratedModel(parameter.iteration);
    }

}






分享到 :
0 人收藏
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

积分:7942463
帖子:1588486
精华:0
期权论坛 期权论坛
发布
内容

下载期权论坛手机APP