Introduction:
Solr has revolutionized the search world through its many major deployments, and Mahout is an exciting tool for machine learning work. In this article, I am going to cover integrating Solr with Mahout to classify documents at index time.
Classification:
Classification here means categorizing content into a pre-defined set of categories. The process depends on a model built from a training set. I will cover Mahout classification in detail in my next blog.
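For a rough idea of how such a model is built, here is a sketch using the Bayes trainer in Mahout's command-line driver. The input path is hypothetical and the exact options vary between Mahout releases, so check the documentation for your version:

$MAHOUT_HOME/bin/mahout trainclassifier \
  -i /home/selvam/training-data \
  -o /home/selvam/bayes-model \
  -type bayes -ng 1 -source hdfs

The input directory holds your labelled training documents; the output directory is the model path we will point Solr at below.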
Implementation:
I am going to hook into the Solr update process, call the Mahout classifier, and add a category field based on the classifier's result, so every document gets its category assigned automatically as it is indexed. Add the following configuration to solrconfig.xml.
Solr Config:
<updateRequestProcessorChain name="mlinterceptor" default="true">
  <processor class="org.apache.solr.update.processor.ext.CategorizeDocumentFactory">
    <str name="inputField">content</str>
    <str name="outputField">category</str>
    <str name="defaultCategory">Others</str>
    <str name="model">/home/selvam/bayes-model</str>
  </processor>
  <processor class="solr.RunUpdateProcessorFactory"/>
  <processor class="solr.LogUpdateProcessorFactory"/>
</updateRequestProcessorChain>

<requestHandler name="/update" class="solr.XmlUpdateRequestHandler">
  <lst name="defaults">
    <str name="update.processor">mlinterceptor</str>
  </lst>
</requestHandler>
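Once the JAR described next is in place, any document posted to /update passes through the classifier. A quick way to exercise the chain, assuming your schema defines id, content and category fields (the URL and values are illustrative):

curl "http://localhost:8983/solr/update?commit=true" -H "Content-Type: text/xml" \
  --data-binary '<add><doc><field name="id">1</field><field name="content">some text for the classifier to categorize</field></doc></add>'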
org.apache.solr.update.processor.ext.CategorizeDocumentFactory is our custom Java class, compiled into a JAR. Place this JAR in the solr/lib directory.
Code:
package org.apache.solr.update.processor.ext;

import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;

import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.update.AddUpdateCommand;
import org.apache.solr.update.processor.UpdateRequestProcessor;
import org.apache.solr.update.processor.UpdateRequestProcessorFactory;

import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.analysis.tokenattributes.TermAttribute;
import org.apache.lucene.util.Version;

// Mahout classifier
import org.apache.mahout.classifier.ClassifierResult;
import org.apache.mahout.classifier.bayes.algorithm.BayesAlgorithm;
import org.apache.mahout.classifier.bayes.common.BayesParameters;
import org.apache.mahout.classifier.bayes.datastore.InMemoryBayesDatastore;
import org.apache.mahout.classifier.bayes.interfaces.Algorithm;
import org.apache.mahout.classifier.bayes.interfaces.Datastore;
import org.apache.mahout.classifier.bayes.model.ClassifierContext;

public class CategorizeDocumentFactory extends UpdateRequestProcessorFactory {

  SolrParams params;
  ClassifierContext ctx;

  public void init(NamedList args) {
    params = SolrParams.toSolrParams(args);
    BayesParameters p = new BayesParameters();
    p.setBasePath(params.get("model"));
    // Load the Bayes model into memory once, when the factory is initialized.
    Datastore ds = new InMemoryBayesDatastore(p);
    Algorithm alg = new BayesAlgorithm();
    ctx = new ClassifierContext(alg, ds); // assign the field; do not shadow it with a local
    try {
      ctx.initialize();
    } catch (Exception e) {
      // Without a working classifier the chain is useless, so fail loudly.
      throw new RuntimeException("Could not initialize the Bayes classifier", e);
    }
  }

  @Override
  public UpdateRequestProcessor getInstance(SolrQueryRequest req, SolrQueryResponse rsp,
                                            UpdateRequestProcessor next) {
    return new CategorizeDocument(next);
  }

  public class CategorizeDocument extends UpdateRequestProcessor {

    public CategorizeDocument(UpdateRequestProcessor next) {
      super(next);
    }

    @Override
    public void processAdd(AddUpdateCommand cmd) throws IOException {
      try {
        SolrInputDocument doc = cmd.getSolrInputDocument();
        String inputField = params.get("inputField");
        String outputField = params.get("outputField");
        String defaultCategory = params.get("defaultCategory", "Others");
        String input = (String) doc.getFieldValue(inputField);
        if (input != null) {
          // Tokenize the input field; the analyzer should match the one used at training time.
          ArrayList<String> tokenList = new ArrayList<String>(256);
          StandardAnalyzer analyzer = new StandardAnalyzer(Version.LUCENE_30);
          TokenStream ts = analyzer.tokenStream(inputField, new StringReader(input));
          TermAttribute termAtt = ts.addAttribute(TermAttribute.class);
          ts.reset();
          while (ts.incrementToken()) {
            tokenList.add(termAtt.term()); // term(), not toString(), which prefixes "term="
          }
          ts.close();
          String[] tokens = tokenList.toArray(new String[tokenList.size()]);
          // Call the Mahout classification process.
          ClassifierResult result = ctx.classifyDocument(tokens, defaultCategory);
          if (result != null && !result.getLabel().isEmpty()) {
            doc.addField(outputField, result.getLabel());
          }
        }
      } catch (Exception e) {
        // Never block indexing because categorization failed; log and move on.
        e.printStackTrace();
      }
      super.processAdd(cmd);
    }
  }
}
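To build the JAR, compile the class against the Solr, Lucene and Mahout libraries. A minimal sketch, assuming those jars have been collected under a lib/ directory (jar name is up to you):

javac -cp "lib/*" org/apache/solr/update/processor/ext/CategorizeDocumentFactory.java
jar cf solr-categorizer.jar org/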
When you start Solr, it may take a little longer than usual because the classification model is loaded into memory. Don't worry: it is loaded only once and kept in memory, so the classification itself will be lightning-fast :)
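To confirm that incoming documents are actually getting categorized, query one back and check the output field (continuing the illustrative example from above):

curl "http://localhost:8983/solr/select?q=id:1&fl=id,category"

The response should carry the label chosen by the classifier, or the configured defaultCategory when nothing matches.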
Starting Solr:
If your model is too big, you might get a Java heap error. In that case, increase the heap size when starting Solr, and optionally enable the concurrent collector (the 2048m figure is illustrative; size it to your model):
java -Xmx2048m -XX:+UseConcMarkSweepGC -jar start.jar