001 package org.maltparser.parser.guide.instance;
002
003 import java.io.IOException;
004 import java.lang.reflect.Constructor;
005 import java.lang.reflect.InvocationTargetException;
006 import java.util.ArrayList;
007 import java.util.Formatter;
008
009 import org.maltparser.core.exception.MaltChainedException;
010 import org.maltparser.core.feature.FeatureVector;
011 import org.maltparser.core.feature.function.FeatureFunction;
012 import org.maltparser.core.feature.function.Modifiable;
013 import org.maltparser.core.syntaxgraph.DependencyStructure;
014 import org.maltparser.ml.LearningMethod;
015 import org.maltparser.parser.guide.ClassifierGuide;
016 import org.maltparser.parser.guide.GuideException;
017 import org.maltparser.parser.guide.Model;
018 import org.maltparser.parser.history.action.SingleDecision;
019
020
021 /**
022
023 @author Johan Hall
024 @since 1.0
025 */
026 public class AtomicModel implements InstanceModel {
027 private Model parent;
028 private String modelName;
029 private FeatureVector featureVector;
030 private int index;
031 private int frequency = 0;
032 private LearningMethod method;
033
034
035 /**
036 * Constructs an atomic model.
037 *
038 * @param index the index of the atomic model (-1..n), where -1 is special value (used by a single model
039 * or the master divide model) and n is number of divide models.
040 * @param features the feature vector used by the atomic model.
041 * @param parent the parent guide model.
042 * @throws MaltChainedException
043 */
044 public AtomicModel(int index, FeatureVector features, Model parent) throws MaltChainedException {
045 setParent(parent);
046 setIndex(index);
047 if (index == -1) {
048 setModelName(parent.getModelName()+".");
049 } else {
050 setModelName(parent.getModelName()+"."+new Formatter().format("%03d", index)+".");
051 }
052 setFeatures(features);
053 setFrequency(0);
054 initMethod();
055 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH && index == -1 && getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter() != null) {
056 try {
057 getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter().write(method.toString());
058 getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter().flush();
059 } catch (IOException e) {
060 throw new GuideException("Could not write learner settings to the information file. ", e);
061 }
062 }
063 }
064
065 public void addInstance(SingleDecision decision) throws MaltChainedException {
066 try {
067 method.addInstance(decision, featureVector);
068 } catch (NullPointerException e) {
069 throw new GuideException("The learner cannot be found. ", e);
070 }
071 }
072
073
074 public void noMoreInstances() throws MaltChainedException {
075 try {
076 method.noMoreInstances();
077 } catch (NullPointerException e) {
078 throw new GuideException("The learner cannot be found. ", e);
079 }
080 }
081
082 public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException {
083 try {
084 method.finalizeSentence(dependencyGraph);
085 } catch (NullPointerException e) {
086 throw new GuideException("The learner cannot be found. ", e);
087 }
088 }
089
090 public boolean predict(SingleDecision decision) throws MaltChainedException {
091 try {
092 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
093 throw new GuideException("Cannot predict during batch training. ");
094 }
095 return method.predict(featureVector, decision);
096 } catch (NullPointerException e) {
097 throw new GuideException("The learner cannot be found. ", e);
098 }
099 }
100
101 public FeatureVector predictExtract(SingleDecision decision) throws MaltChainedException {
102 try {
103 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
104 throw new GuideException("Cannot predict during batch training. ");
105 }
106 if (method.predict(featureVector, decision)) {
107 return featureVector;
108 }
109 return null;
110 } catch (NullPointerException e) {
111 throw new GuideException("The learner cannot be found. ", e);
112 }
113 }
114
115 public FeatureVector extract() throws MaltChainedException {
116 return featureVector;
117 }
118
119 public void terminate() throws MaltChainedException {
120 if (method != null) {
121 method.terminate();
122 method = null;
123 }
124 featureVector = null;
125 parent = null;
126 }
127
128 /**
129 * Moves all instance from this atomic model into the destination atomic model and add the divide feature.
130 * This method is used by the feature divide model to sum up all model below a certain threshold.
131 *
132 * @param model the destination atomic model
133 * @param divideFeature the divide feature
134 * @param divideFeatureIndexVector the divide feature index vector
135 * @throws MaltChainedException
136 */
137 public void moveAllInstances(AtomicModel model, FeatureFunction divideFeature, ArrayList<Integer> divideFeatureIndexVector) throws MaltChainedException {
138 if (method == null) {
139 throw new GuideException("The learner cannot be found. ");
140 } else if (model == null) {
141 throw new GuideException("The guide model cannot be found. ");
142 } else if (divideFeature == null) {
143 throw new GuideException("The divide feature cannot be found. ");
144 } else if (divideFeatureIndexVector == null) {
145 throw new GuideException("The divide feature index vector cannot be found. ");
146 }
147 ((Modifiable)divideFeature).setFeatureValue(index);
148 method.moveAllInstances(model.getMethod(), divideFeature, divideFeatureIndexVector);
149 method.terminate();
150 method = null;
151 }
152
153 /**
154 * Invokes the train() of the learning method
155 *
156 * @throws MaltChainedException
157 */
158 public void train() throws MaltChainedException {
159 try {
160 method.train(featureVector);
161 method.terminate();
162 method = null;
163 } catch (NullPointerException e) {
164 throw new GuideException("The learner cannot be found. ", e);
165 }
166 }
167
168 /**
169 * Initialize the learning method according to the option --learner-method.
170 *
171 * @throws MaltChainedException
172 */
173 public void initMethod() throws MaltChainedException {
174 Class<?> clazz = (Class<?>)getGuide().getConfiguration().getOptionValue("guide", "learner");
175 // if (clazz == org.maltparser.ml.libsvm.Libsvm.class && (Boolean)getGuide().getConfiguration().getOptionValue("malt0.4", "behavior") == true) {
176 // try {
177 // clazz = Class.forName("org.maltparser.ml.libsvm.malt04.LibsvmMalt04");
178 // } catch (ClassNotFoundException e) {
179 // throw new GuideException("Could not find the class 'org.maltparser.ml.libsvm.malt04.LibsvmMalt04'. ", e);
180 // }
181 // }
182 Class<?>[] argTypes = { org.maltparser.parser.guide.instance.InstanceModel.class, java.lang.Integer.class };
183 Object[] arguments = new Object[2];
184 arguments[0] = this;
185 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
186 arguments[1] = LearningMethod.CLASSIFY;
187 } else if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
188 arguments[1] = LearningMethod.BATCH;
189 }
190
191 try {
192 Constructor<?> constructor = clazz.getConstructor(argTypes);
193 this.method = (LearningMethod)constructor.newInstance(arguments);
194 } catch (NoSuchMethodException e) {
195 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
196 } catch (InstantiationException e) {
197 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
198 } catch (IllegalAccessException e) {
199 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
200 } catch (InvocationTargetException e) {
201 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
202 }
203 }
204
205
206
207 /**
208 * Returns the parent guide model
209 *
210 * @return the parent guide model
211 */
212 public Model getParent() throws MaltChainedException {
213 if (parent == null) {
214 throw new GuideException("The atomic model can only be used by a parent model. ");
215 }
216 return parent;
217 }
218
219 /**
220 * Sets the parent guide model
221 *
222 * @param parent the parent guide model
223 */
224 protected void setParent(Model parent) {
225 this.parent = parent;
226 }
227
228 public String getModelName() {
229 return modelName;
230 }
231
232 /**
233 * Sets the name of the atomic model
234 *
235 * @param modelName the name of the atomic model
236 */
237 protected void setModelName(String modelName) {
238 this.modelName = modelName;
239 }
240
241 /**
242 * Returns the feature vector used by this atomic model
243 *
244 * @return a feature vector object
245 */
246 public FeatureVector getFeatures() {
247 return featureVector;
248 }
249
250 /**
251 * Sets the feature vector used by the atomic model.
252 *
253 * @param features a feature vector object
254 */
255 protected void setFeatures(FeatureVector features) {
256 this.featureVector = features;
257 }
258
259 public ClassifierGuide getGuide() {
260 return parent.getGuide();
261 }
262
263 /**
264 * Returns the index of the atomic model
265 *
266 * @return the index of the atomic model
267 */
268 public int getIndex() {
269 return index;
270 }
271
272 /**
273 * Sets the index of the model (-1..n), where -1 is a special value.
274 *
275 * @param index index value (-1..n) of the atomic model
276 */
277 protected void setIndex(int index) {
278 this.index = index;
279 }
280
281 /**
282 * Returns the frequency (number of instances)
283 *
284 * @return the frequency (number of instances)
285 */
286 public int getFrequency() {
287 return frequency;
288 }
289
290 /**
291 * Increase the frequency by 1
292 */
293 public void increaseFrequency() {
294 if (parent instanceof InstanceModel) {
295 ((InstanceModel)parent).increaseFrequency();
296 }
297 frequency++;
298 }
299
300 public void decreaseFrequency() {
301 if (parent instanceof InstanceModel) {
302 ((InstanceModel)parent).decreaseFrequency();
303 }
304 frequency--;
305 }
306 /**
307 * Sets the frequency (number of instances)
308 *
309 * @param frequency (number of instances)
310 */
311 protected void setFrequency(int frequency) {
312 this.frequency = frequency;
313 }
314
315 /**
316 * Returns a learner object
317 *
318 * @return a learner object
319 */
320 public LearningMethod getMethod() {
321 return method;
322 }
323
324
325 /* (non-Javadoc)
326 * @see java.lang.Object#toString()
327 */
328 public String toString() {
329 final StringBuilder sb = new StringBuilder();
330 sb.append(method.toString());
331 return sb.toString();
332 }
333 }