001package com.hfg.bio.phylogeny;
002
003import com.hfg.network.Edge;
004
005import java.util.Map;
006import java.util.HashMap;
007import java.util.Collection;
008import java.util.Iterator;
009
010//------------------------------------------------------------------------------
011/**
012 * NJ (Neighbor-Joining) method of phylogenetic tree construction.
013 * First described by <i>Saitou N, Nei M (1987). "The neighbor-joining method: a new
014 * method for reconstructing phylogenetic trees". Mol Biol Evol 4 (4): 406-425.</i>
015 * See <a href='http://en.wikipedia.org/wiki/Neighbor-joining'>wikipedia</a>.
016 * Note that distances in the resulting tree will not exactly match those from
017 * the input distance matrix.
018 * @author J. Alex Taylor, hairyfatguy.com
019 */
020//------------------------------------------------------------------------------
021// com.hfg Library
022//
023// This library is free software; you can redistribute it and/or
024// modify it under the terms of the GNU Lesser General Public
025// License as published by the Free Software Foundation; either
026// version 2.1 of the License, or (at your option) any later version.
027//
028// This library is distributed in the hope that it will be useful,
029// but WITHOUT ANY WARRANTY; without even the implied warranty of
030// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
031// Lesser General Public License for more details.
032//
033// You should have received a copy of the GNU Lesser General Public
034// License along with this library; if not, write to the Free Software
035// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
036//
037// J. Alex Taylor, President, Founder, CEO, COO, CFO, OOPS hairyfatguy.com
038// jataylor@hairyfatguy.com
039//------------------------------------------------------------------------------
040
041public class NJ implements TreeMethod
042{
043
044   //**************************************************************************
045   // PUBLIC METHODS
046   //**************************************************************************
047
048   //--------------------------------------------------------------------------
049   @Override
050   public String toString()
051   {
052      return getClass().getSimpleName();
053   }
054
055   //--------------------------------------------------------------------------
056   @Override
057   public boolean equals(Object inObj2)
058   {
059      return (inObj2 != null
060            && inObj2.getClass().equals(getClass()));
061   }
062
063   //--------------------------------------------------------------------------
064   public PhylogeneticTree constructTree(DistanceMatrix inDistanceMatrix)
065   {
066      int nodeIndex = 1;
067      Map<String, PhyloNode> nodeMap = new HashMap<>();
068
069      // Start with a star tree
070      //         A
071      //    F    |    B
072      //      \  |  /
073      //       \ | /
074      //        \|/
075      //        /|\
076      //       / | \
077      //      /  |  \
078      //    E    |    C
079      //         D
080      //
081
082      PhyloNode centerNode = new PhyloNode();
083      for (String otu : inDistanceMatrix.keySet())
084      {
085         PhyloNode otuNode = new PhyloNode().setLabel(otu);
086         nodeMap.put(otu, otuNode);
087         centerNode.addEdge(otuNode, null);
088      }
089
090      // Clone the distance matrix so we can consume it
091      DistanceMatrix matrix = inDistanceMatrix.clone();
092
093      while (matrix.numKeys() > 2)
094      {
095         Map<String, Float> netDivergenceMap = calculateNetDivergenceMap(matrix);
096         DistanceMatrix qMatrix = calculateQMatrix(matrix, netDivergenceMap);
097
098         // Note: OTU stands for 'Operational Taxonomic Unit'
099
100         Edge<String> shortestEdge = qMatrix.getShortestEdge();
101         String minOTU1 = shortestEdge.getFrom();
102         String minOTU2 = shortestEdge.getTo();
103         float  minDistance = matrix.getDistance(minOTU1, minOTU2);
104
105         String newNodeName = "_" + (nodeIndex++);
106         PhyloNode newNode = new PhyloNode().setLabel(newNodeName);
107         centerNode.addEdge(newNode, null);
108         nodeMap.put(newNodeName, newNode);
109
110         float distance1 = minDistance / 2
111                           + (netDivergenceMap.get(minOTU1) - netDivergenceMap.get(minOTU2)) / (2 * (matrix.numKeys() - 2));
112         float distance2 = minDistance - distance1;
113
114         // Deal with Negative branch lengths
115         // From : http://www.deduveinstitute.be/~opperd/private/neighbor.html
116         // "As the neighbor-joining algorithm seeks to represent the data in the form of an additive tree, it can assign
117         // a negative length to the branch. Here the interpretation of branch lengths as an estimated number of substitutions
118         // gets into difficulties. When this occurs it is adviced to set the branch length to zero and transfer the difference
119         // to the adjacent branch length so that the total distance between an adjacent pair of terminal nodes remains
120         // unaffected. This does not alter the overall topology of the tree (Kuhner and Felsenstein, 1994)."
121         if (distance1 < 0)
122         {
123            distance2 += -distance1;
124            distance1 = 0;
125         }
126
127         PhyloNode node1 = nodeMap.get(minOTU1);
128         node1.removeEdgeFrom(centerNode);
129         newNode.addEdge(node1, distance1);
130         nodeMap.remove(node1.getLabel());
131
132         PhyloNode node2 = nodeMap.get(minOTU2);
133         node2.removeEdgeFrom(centerNode);
134         newNode.addEdge(node2, distance2);
135         nodeMap.remove(node2.getLabel());
136
137         matrix.addKey(newNodeName);
138         for (String key : matrix.keySet())
139         {
140            if (key.equals(minOTU1) || key.equals(minOTU2) || key.equals(newNodeName)) continue;
141            float distance = (matrix.getDistance(key, minOTU1) + matrix.getDistance(key, minOTU2) - minDistance) / 2;
142            matrix.setDistance(newNodeName, key, distance);
143         }
144         matrix.removeKey(minOTU1);
145         matrix.removeKey(minOTU2);
146      }
147
148      // Now remove the center node and connect the last two nodes.
149      for (PhyloNode node : nodeMap.values())
150      {
151         node.removeEdgeFrom(centerNode);
152      }
153
154      Iterator<String> iter = nodeMap.keySet().iterator();
155      String otu1 = iter.next();
156      String otu2 = iter.next();
157
158      nodeMap.get(otu1).addEdge(nodeMap.get(otu2), matrix.getDistance(otu1, otu2));
159
160      PhylogeneticTree tree = new PhylogeneticTree();
161
162      PhyloNode startingNode = nodeMap.get(otu1);
163      if (nodeMap.get(otu2).getEdges().size() > startingNode.getEdges().size())
164      {
165         startingNode = nodeMap.get(otu2);
166      }
167
168      tree.setRootNode(startingNode);
169
170//      tree.rootTreeByMidpointMethod();
171//      tree.orderByNodeCount();
172
173      return tree;
174   }
175
176   //**************************************************************************
177   // PRIVATE METHODS
178   //**************************************************************************
179
180   //--------------------------------------------------------------------------
181   private static DistanceMatrix calculateQMatrix(DistanceMatrix inDistanceMatrix, Map<String, Float> inNetDivergenceMap)
182   {
183      DistanceMatrix qMatrix = new DistanceMatrix();
184
185      Collection<String> matrixKeys = inDistanceMatrix.keySet();
186      int N = matrixKeys.size();
187      for (String key : matrixKeys)
188      {
189         float netDivergence1 = inNetDivergenceMap.get(key);
190         for (String key2 : matrixKeys)
191         {
192            if (! key.equals(key2))
193            {
194               float distance = inDistanceMatrix.getDistance(key, key2) - (netDivergence1 + inNetDivergenceMap.get(key2)) / (float) (N - 2);
195               qMatrix.setDistance(key, key2, distance);
196            }
197         }
198      }
199
200      return qMatrix;
201   }
202
203
204   //--------------------------------------------------------------------------
205   private static Map<String, Float> calculateNetDivergenceMap(DistanceMatrix inDistanceMatrix)
206   {
207      Map<String, Float> netDivergenceMap = new HashMap<>(inDistanceMatrix.numKeys());
208
209      Collection<String> matrixKeys = inDistanceMatrix.keySet();
210      for (String key : matrixKeys)
211      {
212         float netDivergence = inDistanceMatrix.getNetDivergence(key);
213
214         netDivergenceMap.put(key, netDivergence);
215      }
216
217      return netDivergenceMap;
218   }
219}