001package com.hfg.bio.seq.alignment.matrix;
002
003
004import java.util.ArrayList;
005import java.util.Collection;
006import java.util.Collections;
007import java.util.HashMap;
008import java.util.HashSet;
009import java.util.List;
010import java.util.Map;
011import java.util.Set;
012
013import com.hfg.bio.AminoAcid;
014import com.hfg.bio.AminoAcidFreqTable;
015import com.hfg.bio.seq.BioSequence;
016import com.hfg.bio.seq.BioSequenceType;
017import com.hfg.bio.seq.PositionalFrequencyMatrix;
018import com.hfg.bio.seq.Protein;
019import com.hfg.bio.seq.alignment.MultipleSequenceAlignment;
020import com.hfg.math.MathUtil;
021import com.hfg.util.StringBuilderPlus;
022import com.hfg.util.StringUtil;
023import com.hfg.util.collection.OrderedSet;
024import com.hfg.xml.XMLTag;
025
026//------------------------------------------------------------------------------
027/**
028 Container for a Position-Specific Scoring Matrix (PSSM). Also known as a
029 Position-Weighted Matrix.
030 <div>
031   See <a href='http://en.wikipedia.org/wiki/Position_weight_matrix'>http://en.wikipedia.org/wiki/Position_weight_matrix</a>
032 </div>
033 @author J. Alex Taylor, hairyfatguy.com
034 */
035//------------------------------------------------------------------------------
036// com.hfg Library
037//
038// This library is free software; you can redistribute it and/or
039// modify it under the terms of the GNU Lesser General Public
040// License as published by the Free Software Foundation; either
041// version 2.1 of the License, or (at your option) any later version.
042//
043// This library is distributed in the hope that it will be useful,
044// but WITHOUT ANY WARRANTY; without even the implied warranty of
045// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
046// Lesser General Public License for more details.
047//
048// You should have received a copy of the GNU Lesser General Public
049// License along with this library; if not, write to the Free Software
050// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
051//
052// J. Alex Taylor, President, Founder, CEO, COO, CFO, OOPS hairyfatguy.com
053// jataylor@hairyfatguy.com
054//------------------------------------------------------------------------------
055
056public class PSSM
057{
058
059   public enum Flag
060   {
061      CASE_SENSITIVE,
062      LOG2_WEIGHTED
063   }
064
065   private String mName;
066   private BioSequenceType mBioSequenceType;
067
068   private List<Integer> mPositions;
069   private int mMinPosition;
070   private int mMaxPosition;
071   private List<Character> mResidues;
072   private String mResiduesString;
073   private Float[][] mMatrixArray;
074
075   private List<Float> mPositionalGapOpenScore;
076   private List<Float> mPositionalGapExtScore;
077
078   private float  mDefaultScore = 0.0f;
079   private char   mUnknownResidue;
080
081   private Set<Flag> mFlags = new HashSet<>(4);
082
083   // Cached values
084   private Integer mLength;
085   private boolean mCaseSensitive;
086
087   private static final String XML_PSSM     = "PSSM";
088   private static final String XML_NAME_ATT = "name";
089   private static final String XML_RESIDUES = "Residues";
090   private static final String XML_MATRIX_ROW = "MatrixRow";
091   private static final String XML_POSITION_ATT = "position";
092
093   //###########################################################################
094   // CONSTRUCTORS
095   //###########################################################################
096
097   //---------------------------------------------------------------------------
098   /**
099    Constructs a PSSM from a multiple sequence alignment.
100    @param inMSA source multiple sequence alignment
101    @param inFlags flag(s) to apply to the PSSM
102    */
103   public PSSM(MultipleSequenceAlignment inMSA, Flag... inFlags)
104   {
105      this(inMSA, null, inFlags);
106   }
107
108   //---------------------------------------------------------------------------
109   /**
110    Constructs a PSSM from a multiple sequence alignment.
111    @param inMSA source multiple sequence alignment
112    @param inBackgroundAAFreq optional background amino acid frequency table. If null,
113                              an equal residue probability is used.
114    @param inFlags flag(s) to apply to the PSSM
115    */
116   public PSSM(MultipleSequenceAlignment inMSA, AminoAcidFreqTable inBackgroundAAFreq, Flag... inFlags)
117   {
118      // Flags must be set first
119      setFlags(inFlags);
120
121      mBioSequenceType = inMSA.getBioSequenceType();
122      if (BioSequenceType.NUCLEIC_ACID.equals(mBioSequenceType))
123      {
124         mUnknownResidue = 'N';
125      }
126      else
127      {
128         mUnknownResidue = 'X';
129      }
130
131      mResidues = new OrderedSet<>(30);
132
133      PositionalFrequencyMatrix<Protein> freqMatrix;
134      if (mCaseSensitive)
135      {
136         freqMatrix = inMSA.getPositionFreqMatrix(new PositionalFrequencyMatrix.Flag[] {PositionalFrequencyMatrix.Flag.CASE_SENSITIVE});
137      }
138      else
139      {
140         freqMatrix = inMSA.getPositionFreqMatrix();
141      }
142
143      if (freqMatrix != null)
144      {
145         float  totalPsuedocounts = (float) Math.sqrt(inMSA.size());
146
147         for (Character residue : freqMatrix.getResidueKeys())
148         {
149            mResidues.add(residue);
150         }
151
152         Map<Character, Float> backgroundResidueFreqMap = new HashMap<>(30);
153         if (inBackgroundAAFreq != null)
154         {
155            for (AminoAcid aa : inBackgroundAAFreq.keySet())
156            {
157               mResidues.add(aa.getOneLetterCode());
158
159               if (aa.getOneLetterCode() != mUnknownResidue)
160               {
161                  backgroundResidueFreqMap.put(aa.getOneLetterCode(), inBackgroundAAFreq.get(aa));
162               }
163            }
164         }
165         else
166         {
167            for (Character residue : freqMatrix.getResidueKeys())
168            {
169               if (residue != mUnknownResidue)
170               {
171                  backgroundResidueFreqMap.put(residue, 1f);
172               }
173            }
174         }
175
176         mResiduesString = StringUtil.join(mResidues, "");
177
178         boolean log2Weighted = mFlags.contains(Flag.LOG2_WEIGHTED);
179
180         // Determine the number of rows and columns
181         mPositions = new ArrayList<>(freqMatrix.getPositionKeys());
182         Collections.sort(mPositions);
183         mMinPosition = mPositions.get(0);
184         mMaxPosition = mPositions.get(mPositions.size() - 1);
185         int numRows = mMaxPosition - mMinPosition + 1;
186         int numCols = mResidues.size();
187         // Create the 2D array
188         mMatrixArray = new Float[numRows][numCols];
189
190
191         mPositionalGapOpenScore = new ArrayList<>();
192         mPositionalGapExtScore  = new ArrayList<>();
193
194         for (Integer position : freqMatrix.getPositionKeys())
195         {
196            for (Character residue : mResidues)
197            {
198               if (residue != mUnknownResidue)
199               {
200                  Float bgFreq = backgroundResidueFreqMap.get(residue);
201                  // Avoid divide by zero
202                  if (null == bgFreq
203                        || 0.0f == bgFreq)
204                  {
205                     bgFreq = 0.0001f;
206                  }
207
208//                  float weightedScore = ((float) freqMatrix.getCount(residue, position) + (totalPsuedocounts * bgFreq)) / (freqMatrix.getPositionTotal(position) + totalPsuedocounts);
209                  float weightedScore = ((float) freqMatrix.getCount(residue, position) + (totalPsuedocounts * bgFreq)) / (inMSA.size() + totalPsuedocounts);
210
211                  if (log2Weighted)
212                  {
213                     weightedScore = (float) MathUtil.log2(weightedScore);
214                  }
215
216                  int residueIdx = mResiduesString.indexOf(residue);
217                  mMatrixArray[position - mMinPosition][residueIdx] = weightedScore * 10;
218               }
219            }
220
221            float weightedScore = 1 - ((float) freqMatrix.getGapOpenCount(position) + (totalPsuedocounts * 0.0001f)) / (inMSA.size() + totalPsuedocounts);
222            if (log2Weighted)
223            {
224               weightedScore = (float) MathUtil.log2(weightedScore);
225            }
226            mPositionalGapOpenScore.add(weightedScore);
227
228            weightedScore = 1 - ((float) freqMatrix.getGapExtCount(position) + (totalPsuedocounts * 0.0001f)) / (inMSA.size() + totalPsuedocounts);
229            if (log2Weighted)
230            {
231               weightedScore = (float) MathUtil.log2(weightedScore);
232            }
233            mPositionalGapExtScore.add(weightedScore);
234         }
235      }
236   }
237
238   //---------------------------------------------------------------------------
239   public PSSM(XMLTag inXMLTag)
240   {
241      inXMLTag.verifyTagName(XML_PSSM);
242      setName(inXMLTag.getAttributeValue(XML_NAME_ATT));
243
244      XMLTag residuesTag = inXMLTag.getRequiredSubtagByName(XML_RESIDUES);
245      mResidues = new ArrayList<>(30);
246      for (char residue : residuesTag.getContent().toCharArray())
247      {
248         mResidues.add(residue);
249      }
250
251      mResiduesString = StringUtil.join(mResidues, "");
252
253      List<XMLTag> matrixRowTags = inXMLTag.getSubtagsByName(XML_MATRIX_ROW);
254      List<Integer> mPositions = new ArrayList<>(matrixRowTags.size());
255      for (XMLTag rowTag : matrixRowTags)
256      {
257         mPositions.add(Integer.parseInt(rowTag.getAttributeValue(XML_POSITION_ATT)));
258      }
259      Collections.sort(mPositions);
260      mMinPosition = mPositions.get(0);
261      mMaxPosition = mPositions.get(mPositions.size() - 1);
262
263      // Determine the number of rows and columns
264      int numRows = mMaxPosition - mMinPosition + 1;
265      int numCols = mResidues.size();
266      // Create the 2D array
267      mMatrixArray = new Float[numRows][numCols];
268
269      for (XMLTag matrixRowTag : matrixRowTags)
270      {
271         int position = Integer.parseInt(matrixRowTag.getAttributeValue(XML_POSITION_ATT));
272         String[] values = matrixRowTag.getContent().split("\\s+");
273         for (int j = 0; j < mResidues.size(); j++)
274         {
275            mMatrixArray[position - mMinPosition][j] = Float.parseFloat(values[j]);
276         }
277      }
278   }
279
280   //###########################################################################
281   // PUBLIC METHODS
282   //###########################################################################
283
284   //---------------------------------------------------------------------------
285   public float score(int inPosition, char inResidue)
286   {
287      char residue = inResidue;
288      if (! mCaseSensitive)
289      {
290         residue = Character.toUpperCase(residue);
291      }
292
293      float score = mDefaultScore;
294
295      if (residue != mUnknownResidue
296          && inPosition >= mMinPosition
297          && inPosition <= mMaxPosition)
298      {
299//         score = mMatrix.get(inPosition, residue);
300         int residueIdx = mResiduesString.indexOf(residue);
301         if (residueIdx >= 0)
302         {
303            score = mMatrixArray[inPosition - mMinPosition][residueIdx];
304         }
305      }
306
307      return score;
308   }
309
310   //---------------------------------------------------------------------------
311   public float getGapOpenScore(int inPosition)
312   {
313      return mPositionalGapOpenScore.get(inPosition - 1);
314   }
315
316   //---------------------------------------------------------------------------
317   // Used for eaking out a little more performance vs. repeated calls to getGapOpenScore().
318   public float[] getGapOpenScores()
319   {
320      float[] values = new float[mPositionalGapOpenScore.size()];
321      int i = 0;
322      for (Float value : mPositionalGapOpenScore)
323      {
324         values[i++] = (value != null ? value : Float.NaN);
325      }
326
327      return values;
328   }
329
330   //---------------------------------------------------------------------------
331   public float getGapExtScore(int inPosition)
332   {
333      return mPositionalGapExtScore.get(inPosition - 1);
334   }
335
336   //---------------------------------------------------------------------------
337   public String showScoring(Collection<BioSequence> inSeqs)
338   {
339      StringBuilderPlus buffer = new StringBuilderPlus();
340
341      for (Integer position : mPositions)
342      {
343         buffer.append(String.format("%3d. ", position));
344
345         for (BioSequence seq : inSeqs)
346         {
347            char residue = seq.getSequence().charAt(position - 1);
348            buffer.append(String.format("%c %.1f   ", residue, score(position, residue)));
349         }
350
351         buffer.appendln();
352      }
353
354      return buffer.toString();
355   }
356
357   //---------------------------------------------------------------------------
358   public String getName()
359   {
360      return mName;
361   }
362
363   //---------------------------------------------------------------------------
364   public PSSM setName(String inValue)
365   {
366      mName = inValue;
367      return this;
368   }
369
370   //---------------------------------------------------------------------------
371   public int length()
372   {
373      if (null == mLength)
374      {
375         int length = 0;
376         if (mMatrixArray != null)
377         {
378            // Since not all positions might be present in the matrix, return the max position value
379            length = mMaxPosition;
380         }
381
382         mLength = length;
383      }
384
385      return mLength;
386   }
387
388   //---------------------------------------------------------------------------
389   public XMLTag toXMLTag()
390   {
391      XMLTag rootTag = new XMLTag(XML_PSSM);
392      if (StringUtil.isSet(getName()))
393      {
394         rootTag.setAttribute(XML_NAME_ATT, getName());
395      }
396
397      rootTag.addSubtag(new XMLTag(XML_RESIDUES).setContent(StringUtil.join(mResidues, "")));
398
399      for (Integer position : mPositions)
400      {
401         XMLTag matrixRowTag = new XMLTag(XML_MATRIX_ROW);
402         matrixRowTag.setAttribute(XML_POSITION_ATT, position);
403         StringBuilderPlus buffer = new StringBuilderPlus().setDelimiter(" ");
404         for (Character residue : mResidues)
405         {
406            int residueIdx = mResiduesString.indexOf(residue);
407            Float value = mMatrixArray[position - mMinPosition][residueIdx];
408            buffer.delimitedAppend(String.format("%.2f", value));
409         }
410         matrixRowTag.setContent(buffer);
411
412         rootTag.addSubtag(matrixRowTag);
413      }
414
415      return rootTag;
416   }
417
418   //###########################################################################
419   // PRIVATE METHODS
420   //###########################################################################
421
422   //--------------------------------------------------------------------------
423   private void setFlags(Flag[] inFlags)
424   {
425      mFlags.clear();
426
427      mCaseSensitive = false;
428
429      if (inFlags != null)
430      {
431         for (Flag flag : inFlags)
432         {
433            mFlags.add(flag);
434         }
435
436         if (mFlags.contains(Flag.CASE_SENSITIVE))
437         {
438            mCaseSensitive = true;
439         }
440      }
441   }
442}