001package com.hfg.bio.seq;
002
003import com.hfg.bio.seq.alignment.MultipleSequenceAlignment;
004import com.hfg.exception.ProgrammingException;
005import com.hfg.math.Counter;
006import com.hfg.util.collection.OrderedSet;
007import com.hfg.util.collection.SparseMatrix;
008
009import java.util.ArrayList;
010import java.util.Collections;
011import java.util.HashMap;
012import java.util.HashSet;
013import java.util.List;
014import java.util.Map;
015import java.util.Set;
016
017//------------------------------------------------------------------------------
018/**
019 Positional frequency matrix.
020 <div>
021 @author J. Alex Taylor, hairyfatguy.com
022 </div>
023 */
024//------------------------------------------------------------------------------
025// com.hfg Library
026//
027// This library is free software; you can redistribute it and/or
028// modify it under the terms of the GNU Lesser General Public
029// License as published by the Free Software Foundation; either
030// version 2.1 of the License, or (at your option) any later version.
031//
032// This library is distributed in the hope that it will be useful,
033// but WITHOUT ANY WARRANTY; without even the implied warranty of
034// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
035// Lesser General Public License for more details.
036//
037// You should have received a copy of the GNU Lesser General Public
038// License along with this library; if not, write to the Free Software
039// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
040//
041// J. Alex Taylor, President, Founder, CEO, COO, CFO, OOPS hairyfatguy.com
042// jataylor@hairyfatguy.com
043//------------------------------------------------------------------------------
044
045public class PositionalFrequencyMatrix<T extends BioSequence> implements Cloneable
046{
047   public enum Flag
048   {
049      CASE_SENSITIVE
050   }
051
052   private Set<Flag> mFlags = new HashSet<>(4);
053   private SparseMatrix<Character, Integer, Counter> mMatrix;
054   private Map<Integer, Counter> mPositionTotalCountMap;
055   private Counter mSequenceCounter = new Counter();
056   private Map<Integer, Counter> mPositionGapOpenCountMap;
057   private Map<Integer, Counter> mPositionGapExtCountMap;
058
059   //##########################################################################
060   // CONSTRUCTORS
061   //##########################################################################
062
063   //--------------------------------------------------------------------------
064   public PositionalFrequencyMatrix(MultipleSequenceAlignment<T> inMSA)
065   {
066      this(inMSA, (Flag)null);
067   }
068
069   //--------------------------------------------------------------------------
070   public PositionalFrequencyMatrix(MultipleSequenceAlignment<T> inMSA, Flag... inFlags)
071
072   {
073      // Flags must be set first
074      setFlags(inFlags);
075
076      if (inMSA != null)
077      {
078         // Initialize values to zero
079         Set<Character> residueSet;
080         if (BioSequenceType.PROTEIN.equals(inMSA.getBioSequenceType()))
081         {
082            residueSet = ((Protein)inMSA.getSequences().get(0)).getAminoAcidSet().getResidueChars();
083         }
084         else
085         {
086            residueSet = new HashSet<>(30);
087            for (char nucleotide : "ATGCN".toCharArray())
088            {
089               residueSet.add(nucleotide);
090            }
091         }
092
093         init(residueSet, inMSA.length());
094
095         for (T seq : inMSA.getSequences())
096         {
097            addSequence(seq);
098         }
099      }
100   }
101
102   //##########################################################################
103   // PUBLIC METHODS
104   //##########################################################################
105
106   //---------------------------------------------------------------------------
107   @Override
108   public PositionalFrequencyMatrix<T> clone()
109   {
110      PositionalFrequencyMatrix<T> cloneObj;
111      try
112      {
113         cloneObj = (PositionalFrequencyMatrix<T>) super.clone();
114      }
115      catch (CloneNotSupportedException e)
116      {
117         throw new ProgrammingException(e);
118      }
119
120      if (mMatrix != null)
121      {
122         cloneObj.mMatrix = mMatrix.clone();
123      }
124
125      if (mPositionTotalCountMap != null)
126      {
127         for (Integer key : mPositionTotalCountMap.keySet())
128         {
129            cloneObj.mPositionTotalCountMap.put(key, mPositionTotalCountMap.get(key).clone());
130         }
131      }
132
133      return cloneObj;
134   }
135
136   //--------------------------------------------------------------------------
137   public void addSequence(T inSequence)
138   {
139      if (inSequence.length() != mMatrix.colKeySet().size())
140      {
141         throw new RuntimeException("The added sequence (" + inSequence.getID() + ") is not the same length as the others for this " + getClass().getSimpleName() + "!");
142      }
143
144      String sequenceString = inSequence.getSequence();
145      if (! mFlags.contains(Flag.CASE_SENSITIVE))
146      {
147         sequenceString = sequenceString.toUpperCase();
148      }
149
150      int position = 1;
151      char prevResidue = ' ';
152      for (char residue : sequenceString.toCharArray())
153      {
154         if ('-' == residue)
155         {
156            Map<Integer, Counter> gapMap;
157            if (prevResidue != '-')
158            {
159               gapMap = mPositionGapOpenCountMap;
160            }
161            else
162            {
163               gapMap = mPositionGapExtCountMap;
164            }
165
166            Counter counter = gapMap.get(position);
167            if (null == counter)
168            {
169               counter = new Counter();
170               gapMap.put(position, counter);
171            }
172
173            counter.increment();
174         }
175         else
176         {
177            mPositionTotalCountMap.get(position).increment();
178            Counter counter = mMatrix.get(residue, position);
179            if (null == counter)
180            {
181               counter = new Counter();
182               mMatrix.put(residue, position, counter);
183            }
184
185            counter.increment();
186         }
187
188         position++;
189         prevResidue = residue;
190      }
191
192      mSequenceCounter.increment();
193   }
194
195   //--------------------------------------------------------------------------
196   public Set<Character> getResidueKeys()
197   {
198      return mMatrix.rowKeySet();
199   }
200
201   //--------------------------------------------------------------------------
202   public Set<Integer> getPositionKeys()
203   {
204      return mMatrix.colKeySet();
205   }
206
207   //--------------------------------------------------------------------------
208   public void insertGapAtPosition(int inPosition)
209   {
210      List<Integer> positionKeys = new ArrayList<>(getPositionKeys());
211      if (inPosition > 0
212          && inPosition <= positionKeys.get(positionKeys.size() - 1) + 1)
213      {
214         if (inPosition <= positionKeys.size())
215         {
216            for (int i = positionKeys.size(); i >= inPosition; i--)
217            {
218               mMatrix.putCol(i + 1, mMatrix.getCol(i));
219
220               mPositionTotalCountMap.put(i + 1, mPositionTotalCountMap.get(i));
221            }
222         }
223
224         // Put all zero values in the new column
225         for (Character residue : getResidueKeys())
226         {
227            mMatrix.put(residue, inPosition, new Counter());
228         }
229
230
231         int numGapsAtPrevPosition = (inPosition > 1 ? getGapCount(inPosition - 1) : 0);
232
233
234         Counter counter = mPositionGapOpenCountMap.get(inPosition);
235         if (null == counter)
236         {
237            counter = new Counter();
238            mPositionGapOpenCountMap.put(inPosition, counter);
239         }
240
241         counter.add(getPositionTotal(inPosition - 1));
242
243
244         counter = mPositionGapExtCountMap.get(inPosition);
245         if (null == counter)
246         {
247            counter = new Counter();
248            mPositionGapExtCountMap.put(inPosition, counter);
249         }
250
251         counter.add(numGapsAtPrevPosition);
252
253
254         mPositionTotalCountMap.put(inPosition, new Counter());
255      }
256   }
257
258   //--------------------------------------------------------------------------
259   public int getCount(Character inResidue, Integer inPosition)
260   {
261      int count = 0;
262
263      if (inResidue == '-')
264      {
265         count = getGapCount(inPosition);
266      }
267      else
268      {
269         Counter counter = mMatrix.get(inResidue, inPosition);
270         if (null == counter)
271         {
272            if (!getPositionKeys().contains(inPosition))
273            {
274               throw new RuntimeException("Position out of bounds: " + inPosition + "!");
275            }
276            else if (!getResidueKeys().contains(inResidue))
277            {
278               throw new RuntimeException("Invalid residue: " + inResidue + "!");
279            }
280         }
281
282         count = counter.intValue();
283      }
284
285      return count;
286   }
287
288   //--------------------------------------------------------------------------
289   public int getGapCount(Integer inPosition)
290   {
291      return getGapOpenCount(inPosition) + getGapExtCount(inPosition);
292   }
293
294   //--------------------------------------------------------------------------
295   public int getGapOpenCount(Integer inPosition)
296   {
297      Counter counter = mPositionGapOpenCountMap.get(inPosition);
298      return (counter != null ? counter.intValue() : 0);
299   }
300
301   //--------------------------------------------------------------------------
302   public int getGapExtCount(Integer inPosition)
303   {
304      Counter counter = mPositionGapExtCountMap.get(inPosition);
305      return (counter != null ? counter.intValue() : 0);
306   }
307
308
309   //--------------------------------------------------------------------------
310   /**
311    Returns the residue with the highest frequency at the specified position or
312    multiple residues if they share the highest frequency value.
313    @param inPosition the (1-based) position to evaluate
314    @return the residue with the highest frequency at the specified position or
315            multiple residues if they share the highest frequency value
316    */
317   public Set<Character> getHighestFreqResidues(Integer inPosition)
318   {
319      Map<Character, Counter> positionMap = mMatrix.getCol(inPosition);
320      int maxCount = 0;
321      for (Character key : positionMap.keySet())
322      {
323         Counter counter = positionMap.get(key);
324         if (counter.intValue() > maxCount)
325         {
326            maxCount = counter.intValue();
327         }
328      }
329
330      Set<Character> resultSet = new OrderedSet<>(10);
331      for (Character key : positionMap.keySet())
332      {
333         if (positionMap.get(key).intValue() == maxCount)
334         {
335            resultSet.add(key);
336         }
337      }
338
339      return resultSet;
340   }
341
342   //--------------------------------------------------------------------------
343   public Set<Character> getResidues(Integer inPosition)
344   {
345      return getResidues(inPosition, null);
346   }
347
348   //--------------------------------------------------------------------------
349   /**
350    * Retrieves the residues at the specified position that are represented above
351    * the specified minimum fraction.
352    * @param inPosition the aligned sequence position
353    * @param inMinFraction the minimum fraction for residues to return. Null returns
354    *                      residues that are present in at least one sequence.
355    * @return a Set of residue characters from the specified position
356    */
357   public Set<Character> getResidues(Integer inPosition, Float inMinFraction)
358   {
359      if (! getPositionKeys().contains(inPosition))
360      {
361         throw new RuntimeException("Position out of bounds: " + inPosition + "!");
362      }
363
364      Map<Character, Counter> residueCountMap = mMatrix.getCol(inPosition);
365      Set<Character> residues = new OrderedSet<>(30);
366      float total = mPositionTotalCountMap.get(inPosition).floatValue();
367      for (Character residue : residueCountMap.keySet())
368      {
369         if (inMinFraction != null)
370         {
371            if (residueCountMap.get(residue).intValue()/total > inMinFraction)
372            {
373               residues.add(residue);
374            }
375         }
376         else if (residueCountMap.get(residue).intValue() > 0)
377         {
378            residues.add(residue);
379         }
380      }
381
382      return residues;
383   }
384
385   //--------------------------------------------------------------------------
386   public float getFraction(Character inResidue, Integer inPosition)
387   {
388      int count = getCount(inResidue, inPosition);
389      float fraction;
390      if (inResidue == '-')
391      {
392         fraction = count / (float) (getPositionTotal(inPosition) + count);
393      }
394      else
395      {
396         fraction = count / (float) getPositionTotal(inPosition);
397      }
398
399      return fraction;
400   }
401
402   //--------------------------------------------------------------------------
403   public float getFractionIncludingGaps(Character inResidue, Integer inPosition)
404   {
405      return getCount(inResidue, inPosition) / (float) (getPositionTotal(inPosition) + getGapCount(inPosition));
406   }
407
408   //--------------------------------------------------------------------------
409   /**
410    Returns the number of sequences with a residue and not a gap at the specified position.
411    * @param inPosition the 1-based position in the matrix being queried
412    * @return the number of sequences with a residue and not a gap at the specified position
413    */
414   public int getPositionTotal(Integer inPosition)
415   {
416      if (! getPositionKeys().contains(inPosition))
417      {
418         throw new RuntimeException("Position out of bounds: " + inPosition + "!");
419      }
420
421      return mPositionTotalCountMap.get(inPosition).intValue();
422   }
423
424   //--------------------------------------------------------------------------
425   public Set<Flag> getFlags()
426   {
427      return Collections.unmodifiableSet(mFlags);
428   }
429
430   //##########################################################################
431   // PRIVATE METHODS
432   //##########################################################################
433
434   //--------------------------------------------------------------------------
435   private void setFlags(Flag[] inFlags)
436   {
437      mFlags.clear();
438
439      if (inFlags != null)
440      {
441         for (Flag flag : inFlags)
442         {
443            mFlags.add(flag);
444         }
445      }
446   }
447
448   //--------------------------------------------------------------------------
449   private void init(Set<Character> inResidueSet, int inLength)
450   {
451      // Initialize values to zero
452      Set<Character> residueSet = new HashSet<>(30);
453      for (Character residue : inResidueSet)
454      {
455         if (null == residue)
456         {
457            residue = 'X';
458         }
459
460         if (! mFlags.contains(Flag.CASE_SENSITIVE))
461         {
462            residue = Character.toUpperCase(residue);
463         }
464
465         residueSet.add(residue);
466      }
467
468      // Initialize the matrix and position total map with zeroed counters
469      mMatrix = new SparseMatrix<>(residueSet.size(), inLength);
470
471      mPositionTotalCountMap = new HashMap<>(inLength);
472
473      for (int position = 1; position <= inLength; position++)
474      {
475         mPositionTotalCountMap.put(position, new Counter());
476
477         for (char residue : residueSet)
478         {
479            mMatrix.put(residue, position, new Counter());
480         }
481      }
482
483      mPositionGapOpenCountMap = new HashMap<>(inLength);
484      mPositionGapExtCountMap  = new HashMap<>(inLength);
485   }
486
487}