/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2009, Red Hat Middleware LLC, and individual contributors
 * as indicated by the @author tags. See the copyright.txt file in the
 * distribution for a full listing of individual contributors.
 *
 * This is free software; you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2.1 of
 * the License, or (at your option) any later version.
 *
 * This software is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this software; if not, write to the Free
 * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
 * 02110-1301 USA, or see the FSF site: http://www.fsf.org.
 */
package org.jboss.cache.api.mvcc.repeatable_read;

import static org.testng.AssertJUnit.assertEquals;

import java.io.Serializable;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import javax.transaction.Status;
import javax.transaction.TransactionManager;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jboss.cache.AbstractSingleCacheTest;
import org.jboss.cache.Cache;
import org.jboss.cache.CacheFactory;
import org.jboss.cache.CacheSPI;
import org.jboss.cache.DefaultCacheFactory;
import org.jboss.cache.Fqn;
import org.jboss.cache.UnitTestCacheFactory;
import org.jboss.cache.config.Configuration;
import org.jboss.cache.config.Configuration.CacheMode;
import org.jboss.cache.factories.UnitTestConfigurationFactory;
import org.jboss.cache.lock.IsolationLevel;
import org.jboss.cache.util.TestingUtil;
import org.testng.annotations.Test;

/**
 * Concurrency tests for a LOCAL cache configured with {@link IsolationLevel#REPEATABLE_READ}.
 * <p>
 * Covers concurrent read-modify-write increments guarded by a forced write lock, concurrent
 * writes against a cache with write-skew checking enabled, and concurrent node
 * creation/removal.
 * 
 * @author Galder Zamarreño
 */
@Test(groups = { "functional" }, testName = "api.mvcc.repeatable_read.ConcurrentRepeatableReadTest")
public class ConcurrentRepeatableReadTest extends AbstractSingleCacheTest 
{
   static final Log log = LogFactory.getLog(ConcurrentRepeatableReadTest.class);
   
   /**
    * Creates the default test cache: LOCAL mode, REPEATABLE_READ isolation, started.
    */
   @Override
   protected CacheSPI createCache() throws Exception 
   {
      UnitTestCacheFactory factory = new UnitTestCacheFactory();
      Configuration cfg = UnitTestConfigurationFactory.createConfiguration(CacheMode.LOCAL);
      cfg.setIsolationLevel(IsolationLevel.REPEATABLE_READ);      
      CacheSPI created = (CacheSPI) factory.createCache(cfg, true, getClass());
      return created;
   }

   /**
    * Runs several transactional increments of {@code /foo/mynode[scalar]} concurrently, each
    * taking a forced write lock, and checks that no increment was lost.
    */
   public void testConcurrentUpdatesNoWriteSkew(Method m) throws Exception {
      final int nbWriters = 10;
      log.debug(m.getName());
      init();
      // The pool must be able to run all writers at once, otherwise the barrier never trips.
      ExecutorService executor = Executors.newFixedThreadPool(nbWriters);
      try {
         CyclicBarrier barrier = new CyclicBarrier(nbWriters + 1);
         List<Future<Void>> futures = new ArrayList<Future<Void>>(nbWriters);
         for (int i = 0; i < nbWriters; i++) {
            log.debug("Schedule execution");
            futures.add(executor.submit(new IncrementNoWriteSkew(barrier)));
         }
         barrier.await(); // wait for all threads to be ready
         barrier.await(); // wait for all threads to finish

         log.debug("All threads finished, let's shutdown the executor and check whether any exceptions were reported");
         for (Future<Void> future : futures) future.get();
      } finally {
         executor.shutdownNow();
      }

      assertEquals(nbWriters, get());
   }

   /**
    * Runs several concurrent transactional writes to {@code /foo/mynode} against a dedicated
    * cache with write-skew checking enabled, and fails if any writer reports an exception.
    */
   public void testConcurrentUpdatesWriteSkew(Method m) throws Exception {
      final int nbWriters = 10;
      // Deliberately not named "cache": the inherited field of that name is the default cache,
      // and every helper below must operate on this write-skew-enabled instance instead.
      CacheSPI writeSkewCache = null;
      ExecutorService executor = null;
      try {
         log.debug(m.getName());
         UnitTestCacheFactory factory = new UnitTestCacheFactory();
         Configuration cfg = UnitTestConfigurationFactory.createConfiguration(CacheMode.LOCAL);
         cfg.setIsolationLevel(IsolationLevel.REPEATABLE_READ);
         cfg.setWriteSkewCheck(true);
         writeSkewCache = (CacheSPI) factory.createCache(cfg, false, getClass());
         writeSkewCache.start();         
         assert writeSkewCache.getConfiguration().isWriteSkewCheck();
         init(writeSkewCache);
         executor = Executors.newFixedThreadPool(nbWriters);
         CyclicBarrier barrier = new CyclicBarrier(nbWriters + 1);
         List<Future<Void>> futures = new ArrayList<Future<Void>>(nbWriters);
         for (int i = 0; i < nbWriters; i++) {
            log.debug("Schedule execution");
            futures.add(executor.submit(new IncrementWriteSkew(barrier, writeSkewCache)));
         }
         barrier.await(); // wait for all threads to be ready
         barrier.await(); // wait for all threads to finish

         log.debug("All threads finished, let's shutdown the executor and check whether any exceptions were reported");
         for (Future<Void> future : futures) future.get();         
      } finally {
         if (executor != null) executor.shutdownNow();
         if (writeSkewCache != null) TestingUtil.killCaches(writeSkewCache);
      }
   }

   /**
    * Runs writer threads that repeatedly create {@code /key0..keyN} nodes alongside remover
    * threads that repeatedly remove them, and rethrows the first exception any thread hit.
    */
   public void testConcurrentCreateRemove() throws Exception {
      final int totalElement = 100;
      final int totalTimes = 20;
      int writer = 10;
      int remover = 5;
      final CountDownLatch startSignalWriter = new CountDownLatch(1);
      final CountDownLatch startSignalOthers = new CountDownLatch(1);
      final CountDownLatch doneSignal = new CountDownLatch(writer + remover);
      final List<Exception> errors = Collections.synchronizedList(new ArrayList<Exception>());
      for (int i = 0; i < writer; i++)
      {
         final int index = i;
         Thread thread = new Thread()
         {
            @Override
            public void run()
            {
               try
               {
                  startSignalWriter.await();
                  for (int j = 0; j < totalTimes; j++)
                  {
                     for (int k = 0; k < totalElement; k++)
                     {
                        cache.put(Fqn.fromElements("key" + k), "key" + k, "value" + k);
                     }
                     if (index == 0 && j == 0)
                     {
                        // The first writer has now created every key once: the cache is full,
                        // we can launch the others.
                        startSignalOthers.countDown();
                     }
                     sleep(50);
                  }
               }
               catch (Exception e)
               {
                  errors.add(e);
               }
               finally
               {
                  // If the first writer failed before filling the cache, release the removers
                  // anyway so they do not block forever; a second countDown() is a no-op.
                  if (index == 0)
                  {
                     startSignalOthers.countDown();
                  }
                  doneSignal.countDown();
               }
            }
         };
         thread.start();
      }
      startSignalWriter.countDown();
      for (int i = 0; i < remover; i++)
      {
         Thread thread = new Thread()
         {
            @Override
            public void run()
            {
               try
               {
                  startSignalOthers.await();
                  for (int j = 0; j < totalTimes; j++)
                  {
                     for (int k = 0; k < totalElement; k++)
                     {
                        cache.removeNode(Fqn.fromElements("key" + k));
                     }
                     sleep(50);
                  }
               }
               catch (Exception e)
               {
                  errors.add(e);
               }
               finally
               {
                  doneSignal.countDown();
               }
            }
         };
         thread.start();
      }
      doneSignal.await();
      if (!errors.isEmpty())
      {
         for (Exception e : errors)
         {
            log.error("Concurrent create/remove thread failed", e);
         }
         throw errors.get(0);
      }
   }

   /**
    * Initialises {@code /foo/mynode[scalar]} to 0 in the default test cache.
    */
   private void init() throws Exception {
      init(cache);
   }

   /**
    * Transactionally initialises {@code /foo/mynode[scalar]} to 0 in the given cache.
    *
    * @param target the cache to write to
    */
   private void init(CacheSPI target) throws Exception {
      TransactionManager tx = getTm(target);
      tx.begin();
      try {
         target.put("/foo/mynode", "scalar", 0);
      } catch (Exception e) {
         tx.setRollbackOnly();
         throw e;
      } finally {
         if (tx.getStatus() == Status.STATUS_ACTIVE) tx.commit();
         else tx.rollback();
      }
   }

   /**
    * In one transaction, reads {@code /foo/mynode[scalar]} under a forced write lock,
    * increments it and writes it back, using the default test cache.
    */
   private void incrementNoWriteSkew() throws Exception {
      TransactionManager tx = getTm();
      tx.begin();
      try {
         // Take the write lock on the read so concurrent incrementers serialise.
         cache.getInvocationContext().getOptionOverrides().setForceWriteLock(true);
         int tmp = (Integer) cache.get("/foo/mynode", "scalar");
         tmp++;
         cache.put("/foo/mynode", "scalar", tmp);
      } catch (Exception e) {
         log.error("Unexpected", e);
         tx.setRollbackOnly();
         throw e;
      } finally {
         if (tx.getStatus() == Status.STATUS_ACTIVE) tx.commit();
         else tx.rollback();
      }
   }

   /**
    * In one transaction, writes a dummy attribute to {@code /foo/mynode} in the given cache,
    * forcing a write lock on the node.
    *
    * @param target the cache to write to
    */
   private void incrementWriteSkew(CacheSPI target) throws Exception {
      TransactionManager tx = getTm(target);
      tx.begin();
      try {
         target.put("/foo/mynode", "_lockthisplease_", "_lockthisplease_");
      } catch (Exception e) {
         log.error("Unexpected", e);
         tx.setRollbackOnly();
         throw e;
      } finally {
         if (tx.getStatus() == Status.STATUS_ACTIVE) tx.commit();
         else tx.rollback();
      }
   }


   /**
    * Transactionally reads {@code /foo/mynode[scalar]} from the default test cache.
    *
    * @return the current scalar value
    */
   public int get() throws Exception {
      TransactionManager tx = getTm();
      tx.begin();
      try {
         int ret = (Integer) cache.get("/foo/mynode", "scalar");
         return ret;
      } catch (Exception e) {
         tx.setRollbackOnly();
         throw e;
      } finally {
         if (tx.getStatus() == Status.STATUS_ACTIVE) tx.commit();
         else tx.rollback();
      }
   }

   /** Returns the transaction manager of the default test cache. */
   private TransactionManager getTm() {
      return getTm(cache);
   }

   /** Returns the transaction manager configured for the given cache. */
   private TransactionManager getTm(CacheSPI target) {
      return target.getConfiguration().getRuntimeConfig().getTransactionManager();
   }

   /**
    * Task that waits on the shared barrier, performs one locked increment on the default
    * cache, then waits on the barrier again (even on failure) so the test thread is released.
    */
   class IncrementNoWriteSkew implements Callable<Void> {
      private final CyclicBarrier barrier;

      public IncrementNoWriteSkew(CyclicBarrier barrier) {
         this.barrier = barrier;
      }

      public Void call() throws Exception {
         try {
            log.debug("Wait for all executions paths to be ready to perform calls");
            barrier.await();
            incrementNoWriteSkew();
            return null;
         } finally {
            log.debug("Wait for all execution paths to finish");
            barrier.await();
         }
      }
   }

   /**
    * Task that waits on the shared barrier, performs one write via
    * {@link #incrementWriteSkew(CacheSPI)} on its target cache, then waits on the barrier again
    * (even on failure) so the test thread is released.
    */
   class IncrementWriteSkew implements Callable<Void> {
      private final CyclicBarrier barrier;
      private final CacheSPI target;

      /** Targets the default test cache. */
      public IncrementWriteSkew(CyclicBarrier barrier) {
         this(barrier, cache);
      }

      public IncrementWriteSkew(CyclicBarrier barrier, CacheSPI target) {
         this.barrier = barrier;
         this.target = target;
      }

      public Void call() throws Exception {
         try {
            log.debug("Wait for all executions paths to be ready to perform calls");
            barrier.await();
            incrementWriteSkew(target);
            return null;
         } finally {
            log.debug("Wait for all execution paths to finish");
            barrier.await();
         }
      }
   }

   
}
