瀏覽代碼

[DATALAD] Recorded changes

Lucas Gautheron 2 年之前
父節點
當前提交
cec6127f6a
共有 4 個文件被更改,包括 76 次插入10 次删除
  1. 2 0
      .gitignore
  2. 9 3
      code/corr.py
  3. 64 7
      code/plot.py
  4. 1 0
      requirements.txt

+ 2 - 0
.gitignore

@@ -0,0 +1,2 @@
+build
+.DS_Store

+ 9 - 3
code/corr.py

@@ -39,10 +39,10 @@ def compute_counts(parameters):
     intersection = AnnotationManager.intersection(
         am.annotations, ['vtc', annotator]
     )
-    intersection['onset'] = intersection.apply(lambda r: np.arange(r['range_onset'], r['range_offset'], 10000), axis = 1)
+    intersection['onset'] = intersection.apply(lambda r: np.arange(r['range_onset'], r['range_offset'], 15000), axis = 1)
     intersection = intersection.explode('onset')
     intersection['range_onset'] = intersection['onset']
-    intersection['range_offset'] = (intersection['range_onset']+10000).clip(upper = intersection['range_offset'])
+    intersection['range_offset'] = (intersection['range_onset']+15000).clip(upper = intersection['range_offset'])
 
     intersection['path'] = intersection.apply(
         lambda r: opj(project.path, 'annotations', r['set'], 'converted', r['annotation_filename']),
@@ -138,7 +138,9 @@ annotators = pd.read_csv('input/annotators.csv')
 annotators['path'] = annotators['corpus'].apply(lambda c: opj('input', c))
 counts = pd.concat([compute_counts(annotator) for annotator in annotators.to_dict(orient = 'records')])
 counts = counts.fillna(0)
+counts.to_csv('counts.csv')  # save the counts table to disk
 
+counts = pd.read_csv('counts.csv', header=[0, 1, 2], index_col=0)  # read_csv is a pandas function, not a DataFrame method; NOTE(review): assumes 3 column levels (as indexed below) and a single-level index - confirm
 truth = np.transpose([counts['count']['truth'][speaker].values for speaker in ['CHI', 'OCH', 'FEM', 'MAL']]).astype(int)
 vtc = np.transpose([counts['count']['vtc'][speaker].values for speaker in ['CHI', 'OCH', 'FEM', 'MAL']]).astype(int)
 
@@ -165,6 +167,10 @@ data = {
     'vtc': vtc.astype(int)
 }
 
+print(f"clips: {data['n_clips']}")
+print("true vocs: {}".format(np.sum(truth)))
+print("vtc vocs: {}".format(np.sum(vtc)))
+
 plt.scatter(data['truth'][:,0]+np.random.normal(0,0.1,truth.shape[0]), data['vtc'][:,0]+np.random.normal(0,0.1,truth.shape[0]))
 plt.scatter(data['truth'][:,1]+np.random.normal(0,0.1,truth.shape[0]), data['vtc'][:,1]+np.random.normal(0,0.1,truth.shape[0]))
 plt.scatter(data['truth'][:,2]+np.random.normal(0,0.1,truth.shape[0]), data['vtc'][:,2]+np.random.normal(0,0.1,truth.shape[0]))
@@ -232,7 +238,7 @@ init = {
     'betas': np.full((truth.shape[1], truth.shape[1]), 1.01)
 }
 
-num_chains = 2
+num_chains = 4
 
 posterior = stan.build(stan_code, data = data)
 fit = posterior.sample(num_chains = num_chains, num_samples = 4000)

+ 64 - 7
code/plot.py

@@ -1,14 +1,57 @@
 import pandas as pd
 import numpy as np
 
-from matplotlib import pyplot as plt
-import seaborn as sns
+import matplotlib
+import matplotlib.pyplot as plt
+matplotlib.use("pgf")
+matplotlib.rcParams.update({
+    "pgf.texsystem": "pdflatex",
+    'font.family': 'serif',
+    "font.serif" : "Times New Roman",
+    'text.usetex': True,
+    'pgf.rcfonts': False,
+})
+
+def set_size(width, fraction=1, ratio = None):
+    """ Set aesthetic figure dimensions to avoid scaling in latex.
+
+    Parameters
+    ----------
+    width: float
+            Width in pts
+    fraction: float
+            Fraction of the width which you wish the figure to occupy
+    ratio: float, optional
+            Height-to-width ratio of the figure; when None, the golden
+            ratio (about 0.618) is used
+
+    Returns
+    -------
+    fig_dim: tuple
+            Dimensions of figure in inches, as (width, height)
+    """
+    # Width of figure, in pts, after applying the requested fraction
+    fig_width_pt = width * fraction
+
+    # Convert from pt to inches (72.27 pt per inch)
+    inches_per_pt = 1 / 72.27
+
+    # Golden ratio to set aesthetic figure height when no ratio is given
+    if ratio is None:
+        ratio = (5 ** 0.5 - 1) / 2
+
+    # Figure width in inches
+    fig_width_in = fig_width_pt * inches_per_pt
+    # Figure height in inches
+    fig_height_in = fig_width_in * ratio
+
+    return fig_width_in, fig_height_in
+
 
 fit = pd.read_csv('fit.csv')
 
-fig = plt.figure(figsize=(8,8))
+fig = plt.figure(figsize=set_size(450, 1, 1))  # 450 pt wide, ratio 1 (square figure)
 axes = [fig.add_subplot(4,4,i+1) for i in range(4*4)]
 
+speakers = ['CHI', 'OCH', 'FEM', 'MAL']
+
 for i in range(4*4):
     ax = axes[i]
     row = i//4+1
@@ -16,18 +59,32 @@ for i in range(4*4):
     label = f'confusion.{row}.{col}'
 
     ax.set_xticks([])
-    ax.set_yticks([])
     ax.set_xticklabels([])
+    ax.set_yticks([])
     ax.set_yticklabels([])
     ax.set_ylim(0,5)
     ax.set_xlim(0,1)
 
+    low = fit[label].quantile(0.025)  # lower bound of the central 95% interval (2.5th percentile)
+    high = fit[label].quantile(0.975)  # upper bound of the central 95% interval (97.5th percentile)
+
+    if row == 1:  # top row: speaker label above each column
+        ax.xaxis.tick_top()
+        ax.set_xticks([0.5])
+        ax.set_xticklabels([speakers[col-1]])
+
     if row == 4:
-        ax.set_xticks(np.linspace(0,1,4, endpoint = False))
-        ax.set_xticklabels(np.linspace(0,1,4, endpoint = False))
+        ax.set_xticks(np.linspace(0.25,1,3, endpoint = False))
+        ax.set_xticklabels(np.linspace(0.25,1,3, endpoint = False))
+
+    if col == 1:  # first column: speaker label next to each row
+        ax.set_yticks([2.5])
+        ax.set_yticklabels([speakers[row-1]])
 
     ax.hist(fit[label], bins = np.linspace(0,1,40), density = True)
+    ax.axvline(fit[label].mean(), linestyle = '--', linewidth = 0.5, color = '#333', alpha = 1)  # dashed line at the sample mean
+    ax.text(0.5, 4.5, f'{low:.2f} - {high:.2f}', ha = 'center', va = 'center')  # print the interval bounds inside the panel
 
 fig.subplots_adjust(wspace = 0, hspace = 0)
-
+plt.savefig('confusion_fit.pdf')
 plt.show()

+ 1 - 0
requirements.txt

@@ -0,0 +1 @@
+.git/annex/objects/12/Gf/MD5E-s19--2ee0a9c2f76b7dcd1f79cd2ab7022f73.txt/MD5E-s19--2ee0a9c2f76b7dcd1f79cd2ab7022f73.txt