ThreadLocal源码解析

ThreadLocal提供线程局部变量, 通过空间换时间的方式保证线程安全

总结

一开始比较理解的是, ThreadThreadLocalMapThreadLocal和key之间的关系 但实际上是, 每一个 Thread中, 存在一个 ThreadLocalMap, map中存放的key就是 ThreadLocal

关系图如图所示: 76ED0F997F3D4E6E901F88E711D5CFFC-2021-08-31-14:02:19.png

Thread消失后, ThreadLocal.ThreadLocalMap将会进行垃圾回收

ThreadLocal源码

/**
 * 每一个Thread中都存在一个ThreadLocalMap, 在第一次通过ThreadLocal使用时, 就会为其初始化ThreadLocalMap
 * ThreadLocalMap中的 key为ThreadLocal, value是对应的值
 * (通过ThreadLocal::get会当前的ThreadLocal当做key, 拿到当前线程中对应的值)
 * ThreadLocalMap当中的Entry是一个弱引用, 当进行垃圾回收时, 就会回收掉弱引用包装的引用对象
 */
public class ThreadLocal<T> {
  // 保存ThreadLocal的hashCode
  private final int threadLocalHashCode = nextHashCode();
  // 通过原子类递增HashCode
  private static AtomicInteger nextHashCode = new AtomicInteger();
  // HashCode递增值
  private static final int HASH_INCREMENT = 0x61c88647;
  // 通过该方法进行HashCode的递增
  private static int nextHashCode() {
    return nextHashCode.getAndAdd(HASH_INCREMENT);
  }

  // 无参构造方法
  public ThreadLocal() {
  }

  // 往ThreadLocal中设置一个值
  public void set(T value) {
    // 获取当前线程
    Thread t = Thread.currentThread();
    // 通过Thread.threadLocals拿到当前线程的ThreadLocalMap引用
    ThreadLocalMap map = getMap(t);
    if (map != null)
      // 如果map不为空, 则直接设置值, key为当前的ThreadLocal, value为入参
      map.set(this, value);
    else
      // 如果map为空, 则创建一个ThreadLocalMap
      createMap(t, value);
  }

  // 创建ThreadLocalMap, 并将引用赋值给Thread.threadLocals
  void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
  }

  // 从ThreadLocal中获取值
  public T get() {
    // 获取当前线程
    Thread t = Thread.currentThread();
    // 通过Thread.threadLocals拿到当前线程的ThreadLocalMap引用
    ThreadLocalMap map = getMap(t);
    // 如果map不为空
    if (map != null) {
      // 通过当前ThreadLocal拿到Entry
      ThreadLocalMap.Entry e = map.getEntry(this);
      if (e != null) {
        @SuppressWarnings("unchecked")
        // 拿到Entry中的value返回
        T result = (T)e.value;
        return result;
      }
    }
    // 如果map为空, 则初始化一个值返回
    return setInitialValue();
  }

  // 从ThreadLocal中删除值
  public void remove() {
    // 通过Thread.threadLocals拿到当前线程的ThreadLocalMap引用
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null)
      // 删除ThreadLocalMap的key为当前ThreadLocal的值
      m.remove(this);
  }

  
  // ===========内部实现类: ThreadLocalMap============
  static class ThreadLocalMap {
    // 初始容量为16, 必须是2的次幂
    private static final int INITIAL_CAPACITY = 16;
    // 会调整大小, 必须是2的次幂
    private Entry[] table;
    // 记录map中的个数
    private int size = 0;
    // 要调整大小的下一个大小值(实际上, 当size大于容量的一般时, 就需要进行扩容)
    // 因为 size >= len * (2/3) * (3/4) = len * 1/2
    private int threshold;

    // 构造方法
    ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
      table = new Entry[INITIAL_CAPACITY];
      int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
      table[i] = new Entry(firstKey, firstValue);
      size = 1;
      // 设置下次要扩容的阈值
      setThreshold(INITIAL_CAPACITY);
    }

    // 继承WeakReference, 当垃圾收集器回收时, 就会直接回收掉弱引用当中的引用对象
    static class Entry extends WeakReference<ThreadLocal<?>> {
      // 以ThreadLocal作为key, 关联的vlaue
      Object value;

      // Entry的构造方法
      Entry(ThreadLocal<?> k, Object v) {
        super(k);
        value = v;
      }
    }

    // 通过key获取Entry
    private Entry getEntry(ThreadLocal<?> key) {
      // 通过hashCode确定index
      int i = key.threadLocalHashCode & (table.length - 1);
      Entry e = table[i];
      // 如果Entry不为空 && key相关
      if (e != null && e.get() == key)
        // 返回Entry
        return e;
      else
        // 如果index没找到匹配的Entry, 则通过挨个遍历的方式寻找
        return getEntryAfterMiss(key, i, e);
    }

		// 通过遍历的方式获取key对应的Entry
    private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
      Entry[] tab = table;
      int len = tab.length;

      while (e != null) {
        ThreadLocal<?> k = e.get();
        if (k == key)
          // 如果找到了, 则返回
          return e;
        if (k == null)
          // 如果引用对象已经被gc回收, 则消除掉旧Entry
          expungeStaleEntry(i);
        else
          // 通过挨个遍历的方式寻找
          i = nextIndex(i, len);
        e = tab[i];
      }
      return null;
    }
  
    // 设置key和value
    private void set(ThreadLocal<?> key, Object value) {

      Entry[] tab = table;
      int len = tab.length;
      // 通过hashCode确定index
      int i = key.threadLocalHashCode & (len-1);

      // 从index开始, 挨个循环寻找对应的不为空的Entry(退出条件为Entry为空)
      for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        // 如果key相等, 则更新value
        if (k == key) {
          e.value = value;
          return;
        }

        // 如果key为空, 则表示该引用对象已被gc回收, 则替换掉就的Entry
        if (k == null) {
          replaceStaleEntry(key, value, i);
          return;
        }
      }

      // 直到有空位, 创建新的Entry赋值到数组中
      tab[i] = new Entry(key, value);
      int sz = ++size;
      if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
    }

    // 删除key
    private void remove(ThreadLocal<?> key) {
      Entry[] tab = table;
      int len = tab.length;
      // 找到key对应的index
      int i = key.threadLocalHashCode & (len-1);
      // 从index开始, 挨个循环寻找对应的不为空的Entry(退出条件为Entry为空)
      for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
        // 如果key相等, 则清除
        if (e.get() == key) {
          e.clear();
          expungeStaleEntry(i);
          return;
        }
      }
    }
  
    // 每次扩容两倍
    private void resize() {
      Entry[] oldTab = table;
      int oldLen = oldTab.length;
      int newLen = oldLen * 2;
      // 创建2倍大小的新数组
      Entry[] newTab = new Entry[newLen];
      int count = 0;

      // 迁移旧数据
      for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        // 如果为null, 则不需要迁移, 说明已经被gc回收了
        if (e != null) {
          // e为Entry, 继承自WeakReference, 调用get方法为了获取弱引用包装的引用对象
          // get方法: 如果此引用对象已被程序或垃圾收集器清除,则此方法返回null。
          ThreadLocal<?> k = e.get();
          if (k == null) {
            // 如果key为空了, 则将vlaue也值为空, 能够让gc进行垃圾收回
            e.value = null;
          } else {
            // 如果不为空, 则通过hashCode重新确定index
            int h = k.threadLocalHashCode & (newLen - 1);
            // 在新数组中, 从当前index开始挨个循环寻找空位, 直到找到空位
            while (newTab[h] != null)
              h = nextIndex(h, newLen);
            // 赋值
            newTab[h] = e;
            count++;
          }
        }
      }

      // 重新设置扩容的阈值
      setThreshold(newLen);
      size = count;
      // 更新引用
      table = newTab;
    }
  
    // 用于寻找下一个index, 传入当前index和数组长度
    private static int nextIndex(int i, int len) {
      // 挨个寻找, 如果大于数组总长度, 则将回到数组开始位置
      return ((i + 1 < len) ? i + 1 : 0);
    }
  }

}
0条评论
头像
ICP证 : 浙ICP备18021271号