메인 콘텐츠로 건너뛰기
Colab에서 실행해 보기 이 노트북에서는 W&B 아티팩트를 사용하여 ML 실험 파이프라인을 추적하는 방법을 알아봅니다. 동영상 튜토리얼을 보며 함께 따라 해 보세요.

아티팩트에 대하여

그리스의 암포라처럼, 아티팩트는 어떤 과정의 결과로 만들어진 객체입니다. ML에서 가장 중요한 아티팩트는 _데이터셋_과 _모델_입니다. 그리고 Cross of Coronado처럼, 이런 중요한 아티팩트는 박물관에 있어야 합니다. 즉, 여러분과 여러분의 팀, 더 나아가 전체 ML 커뮤니티가 이들로부터 배울 수 있도록 잘 분류하고 체계적으로 정리해야 합니다. 결국, 학습 실행을 추적하지 않으면 같은 학습을 반복할 수밖에 없습니다. Artifacts API를 사용하면 W&B Run의 출력으로 Artifact를 기록(log)하거나, 이 다이어그램처럼 Run의 입력으로 Artifact를 사용할 수 있습니다. 아래 예시에서는 학습 실행이 데이터셋을 입력으로 받아 모델을 출력으로 생성합니다.
Artifacts 워크플로 다이어그램
하나의 실행이 다른 실행의 출력을 입력으로 사용할 수 있으므로, ArtifactRun은 함께 방향 그래프(노드는 ArtifactRun으로 구성된 이분 DAG)를 형성합니다. 그리고 화살표는 Run이 소비하거나 생성하는 Artifact를 가리키면서 RunArtifact를 연결합니다.

아티팩트를 사용해 모델과 데이터셋 추적하기

설치 및 임포트

아티팩트는 W&B Python 라이브러리의 일부이며, 버전 0.9.2부터 제공됩니다. 대부분의 ML Python 스택과 마찬가지로 pip로 설치할 수 있습니다.
# wandb 버전 0.9.2 이상과 호환
!pip install wandb -qqq
!apt install tree
import os
import wandb

데이터셋 로깅(Log a Dataset)

먼저 몇 가지 아티팩트를 정의합니다. 이 예시는 PyTorch의 “Basic MNIST Example”를 기반으로 하지만, TensorFlow를 사용하거나 다른 프레임워크, 혹은 순수 Python만으로도 동일하게 구현할 수 있습니다. 먼저 Dataset을 다음과 같이 정의합니다:
  • 파라미터를 선택하기 위한 train 세트
  • 하이퍼파라미터를 선택하기 위한 validation 세트
  • 최종 모델을 평가하기 위한 test 세트
아래 첫 번째 셀은 이 세 개의 데이터셋을 정의합니다.
import random 

import torch
import torchvision
from torch.utils.data import TensorDataset
from tqdm.auto import tqdm

# 결정론적 동작 보장
torch.backends.cudnn.deterministic = True
random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

# 디바이스 설정
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 데이터 파라미터
num_classes = 10
input_shape = (1, 28, 28)

# MNIST 미러 목록에서 느린 미러 제거
torchvision.datasets.MNIST.mirrors = [mirror for mirror in torchvision.datasets.MNIST.mirrors
                                      if not mirror.startswith("http://yann.lecun.com")]

def load(train_size=50_000):
    """
    # 데이터 로드
    """

    # 데이터를 학습 세트와 테스트 세트로 분할
    train = torchvision.datasets.MNIST("./", train=True, download=True)
    test = torchvision.datasets.MNIST("./", train=False, download=True)
    (x_train, y_train), (x_test, y_test) = (train.data, train.targets), (test.data, test.targets)

    # 하이퍼파라미터 튜닝을 위한 검증 세트 분리
    x_train, x_val = x_train[:train_size], x_train[train_size:]
    y_train, y_val = y_train[:train_size], y_train[train_size:]

    training_set = TensorDataset(x_train, y_train)
    validation_set = TensorDataset(x_val, y_val)
    test_set = TensorDataset(x_test, y_test)

    datasets = [training_set, validation_set, test_set]

    return datasets
이는 이 예시 전체에서 반복해서 보게 될 패턴을 보여 줍니다: 데이터를 아티팩트로 로깅하는 코드를 그 데이터를 생성하는 코드 주위에 감싸도록 구성하는 방식입니다. 이 경우에는 데이터를 load하는 코드와 데이터를 load_and_log하는 코드가 분리되어 있습니다. 이는 권장되는 모범 사례입니다. 이 데이터셋들을 아티팩트로 로깅하려면 다음과 같이만 하면 됩니다.
  1. wandb.init()으로 Run을 생성하고 (L4),
  2. 데이터셋을 위한 Artifact를 생성하고 (L10),
  3. 관련된 file들을 저장하고 로깅합니다 (L20, L23).
아래 코드 셀의 예제를 확인한 다음, 더 자세한 내용은 이후 섹션을 펼쳐서 살펴보세요.
def load_and_log():

    # 실행을 시작합니다. 유형으로 레이블을 지정하고 속할 프로젝트를 설정합니다
    with wandb.init(project="artifacts-example", job_type="load-data") as run:
        
        datasets = load()  # 데이터셋을 로드하는 별도의 코드
        names = ["training", "validation", "test"]

        # 🏺 아티팩트를 생성합니다
        raw_data = wandb.Artifact(
            "mnist-raw", type="dataset",
            description="Raw MNIST dataset, split into train/val/test",
            metadata={"source": "torchvision.datasets.MNIST",
                      "sizes": [len(dataset) for dataset in datasets]})

        for name, data in zip(names, datasets):
            # 🐣 아티팩트에 새 파일을 저장하고 내용을 씁니다.
            with raw_data.new_file(name + ".pt", mode="wb") as file:
                x, y = data.tensors
                torch.save((x, y), file)

        # ✍️ 아티팩트를 W&B에 저장합니다.
        run.log_artifact(raw_data)

load_and_log()

wandb.init()

Artifact들을 생성할 Run을 만들 때, 그 Run이 어떤 project에 속할지 지정해야 합니다. 워크플로우에 따라 프로젝트의 범위는 car-that-drives-itself처럼 클 수도 있고 iterative-architecture-experiment-117처럼 작을 수도 있습니다.
모범 사례: 가능하다면, 동일한 Artifact를 공유하는 모든 Run을 하나의 프로젝트 안에 두세요. 이렇게 하면 구성이 단순해집니다. 그래도 걱정하지 마세요 — Artifact는 프로젝트 간에 이동 가능합니다.
여러 종류의 작업을 추적하기 위해, Run을 생성할 때 job_type을 지정해 두는 것이 유용합니다. 이렇게 하면 아티팩트 그래프를 깔끔하게 유지할 수 있습니다.
모범 사례: job_type은 충분히 설명적이어야 하고 파이프라인의 단일 단계에 대응해야 합니다. 여기서는 데이터를 load하는 단계와 데이터를 preprocess하는 단계를 분리합니다.

wandb.Artifact

무언가를 Artifact로 기록하려면, 먼저 Artifact 객체를 만들어야 합니다. Artifact에는 name이 있습니다. 첫 번째 인자는 이 name을 설정합니다.
모범 사례: name은 설명적이면서도 기억하고 입력하기 쉬워야 합니다. 코드의 변수 이름과 대응되도록 하이픈으로 구분된 이름을 사용하는 방식을 권장합니다.
또한 type도 있습니다. Runjob_type과 마찬가지로, 이는 RunArtifact로 이루어진 그래프를 구성하는 데 사용됩니다.
모범 사례: type은 단순해야 합니다. mnist-data-YYYYMMDD보다는 dataset 또는 model과 같은 값을 사용하세요.
또한 사전(dictionary) 형태로 description과 일부 metadata를 지정할 수 있습니다. metadata는 JSON으로 직렬화 가능하기만 하면 됩니다.
모범 사례: metadata는 가능한 한 자세하게 작성하는 것이 좋습니다.

artifact.new_file and run.log_artifact

Artifact 객체를 만들었으면, 그 안에 파일을 추가해야 합니다. 맞습니다. _file_이 아니라 _files_입니다. Artifact는 디렉터리처럼 구조화되어 있고, 파일과 하위 디렉터리를 포함합니다.
모범 사례: 가능하다면 항상 Artifact의 내용을 여러 파일로 나누세요. 이렇게 하면 나중에 확장해야 할 때 도움이 됩니다.
new_file 메서드를 사용하면 파일을 쓰는 작업과 그것을 Artifact에 연결하는 작업을 동시에 수행할 수 있습니다. 아래에서는 두 단계를 분리하는 add_file 메서드를 사용할 것입니다. 모든 파일을 추가한 뒤에는 log_artifact를 호출해 wandb.ai에 로깅해야 합니다. 출력에 몇 개의 URL이 나타나는 것을 볼 수 있는데, 그중에는 Run 페이지에 해당하는 URL도 있습니다. 여기에서 해당 실행의 결과를 확인할 수 있으며, 로그된 Artifact도 모두 볼 수 있습니다. 아래에서 Run 페이지의 다른 구성 요소를 더 잘 활용하는 예제를 살펴보겠습니다.

로깅된 데이터셋 아티팩트 사용하기

W&B의 Artifact는 박물관의 아티팩트와 달리, 단순히 보관하는 것이 아니라 사용하도록 설계되어 있습니다. 실제로 어떻게 동작하는지 살펴보겠습니다. 아래 셀은 원본(raw) 데이터셋을 입력으로 받아 이를 사용해 preprocess된 데이터셋을 생성하는 파이프라인 단계를 정의합니다. normalize가 적용되고, 올바른 형태로 변환된 데이터셋입니다. 다시 한 번, 코드의 핵심 로직인 preprocesswandb와 상호작용하는 코드와 분리해 둔 점에 주목하세요.
def preprocess(dataset, normalize=True, expand_dims=True):
    """
    ## 데이터 준비
    """
    x, y = dataset.tensors

    if normalize:
        # 이미지를 [0, 1] 범위로 스케일링
        x = x.type(torch.float32) / 255

    if expand_dims:
        # 이미지의 shape이 (1, 28, 28)이 되도록 설정
        x = torch.unsqueeze(x, 1)
    
    return TensorDataset(x, y)
이제 이 preprocess 단계를 wandb.Artifact 로깅으로 계측하는 코드를 살펴보겠습니다. 아래 예제에서는 새로운 개념인 Artifactuse할 뿐만 아니라, 이전 단계와 동일하게 log도 수행합니다. ArtifactRun(실행)의 입력이자 출력입니다. 이전 단계와는 다른 종류의 작업임을 명확히 하기 위해 새로운 job_typepreprocess-data를 사용합니다.
def preprocess_and_log(steps):

    with wandb.init(project="artifacts-example", job_type="preprocess-data") as run:

        processed_data = wandb.Artifact(
            "mnist-preprocess", type="dataset",
            description="Preprocessed MNIST dataset",
            metadata=steps)
         
        # ✔️ 사용할 아티팩트 선언
        raw_data_artifact = run.use_artifact('mnist-raw:latest')

        # 📥 필요한 경우 아티팩트 다운로드
        raw_dataset = raw_data_artifact.download()
        
        for split in ["training", "validation", "test"]:
            raw_split = read(raw_dataset, split)
            processed_dataset = preprocess(raw_split, **steps)

            with processed_data.new_file(split + ".pt", mode="wb") as file:
                x, y = processed_dataset.tensors
                torch.save((x, y), file)

        run.log_artifact(processed_data)


def read(data_dir, split):
    filename = split + ".pt"
    x, y = torch.load(os.path.join(data_dir, filename))

    return TensorDataset(x, y)
여기서 주목해야 할 점 중 하나는 전처리 stepspreprocessed_data와 함께 metadata로 저장된다는 것입니다. 실험을 재현 가능하게 만들고 싶다면, 가능한 한 많은 메타데이터를 저장해 두는 것이 좋습니다. 또한, 우리의 데이터셋이 “large artifact”이긴 하지만, download 단계는 1초도 채 걸리지 않습니다. 자세한 내용은 아래의 마크다운 셀을 펼쳐 확인하세요.
steps = {"normalize": True,
         "expand_dims": True}

preprocess_and_log(steps)

run.use_artifact()

이 단계는 더 간단합니다. 사용자는 Artifactname과 약간의 추가 정보만 알면 됩니다. 여기서 그 “약간의 추가 정보”란, 사용하려는 특정 버전의 Artifact에 해당하는 alias입니다. 기본적으로 마지막으로 업로드된 버전에는 latest 태그가 붙습니다. 또는 v0/v1 등으로 더 이전 버전을 선택할 수도 있고, best 또는 jit-script와 같은 사용자 정의 alias를 지정할 수도 있습니다. Docker Hub 태그와 마찬가지로, alias는 이름과 :로 구분되므로, 우리가 사용하려는 아티팩트는 mnist-raw:latest입니다.
모범 사례: alias는 짧고 간결하게 유지하세요. 특정 속성을 만족하는 아티팩트를 사용하려면 latestbest와 같은 사용자 정의 alias를 사용하세요.

artifact.download

이제 download 호출이 걱정될 수 있습니다. 만약 또 다른 사본을 다운로드하면, 메모리 사용량이 두 배로 늘어나는 것 아닐까요? 걱정하지 마세요. 실제로 어떤 것도 다운로드하기 전에, 먼저 올바른 버전이 로컬에 있는지 확인합니다. 이는 토렌트git을 이용한 버전 관리의 기반 기술과 동일한 해싱(hashing)을 사용합니다. 아티팩트가 생성되고 로깅될 때마다, 작업 디렉터리에 있는 artifacts라는 폴더가 하위 디렉터리로 채워지기 시작하며, 각 아티팩트마다 하나씩 생성됩니다. !tree artifacts 명령으로 그 내용을 확인해 보세요:
!tree artifacts

아티팩트 페이지

이제 Artifact를 기록하고 사용해 보았으니, 실행 페이지의 Artifacts 탭을 살펴보세요. wandb 출력에 표시된 실행 페이지 URL로 이동한 다음 왼쪽 사이드바에서 “Artifacts” 탭을 선택하세요 (데이터베이스 아이콘이 있는 항목으로, 위로 차곡차곡 쌓인 세 개의 하키 퍽처럼 보입니다). Input Artifacts 테이블이나 Output Artifacts 테이블에서 행을 하나 클릭한 다음, (Overview, Metadata) 탭을 확인하여 Artifact에 대해 기록된 모든 내용을 살펴보세요. 특히 Graph View를 추천합니다. 기본적으로, Artifacttype과 실행의 job_type을 두 종류의 노드로 하는 그래프를 표시하며, 소비와 생성 관계를 화살표로 표시합니다.

모델 로깅

Artifact API가 어떻게 동작하는지 이해하기에는 지금까지로도 충분하지만, 파이프라인 끝까지 예제를 따라가 보면서 Artifact(아티팩트)가 ML 워크플로를 어떻게 개선할 수 있는지 살펴보겠습니다. 이 첫 번째 셀에서는 PyTorch로 DNN model을 구성합니다. 아주 간단한 ConvNet입니다. 우선 model만 초기화하고, 학습은 진행하지 않겠습니다. 이렇게 하면 나머지 조건은 그대로 두고 학습만 반복해서 수행할 수 있습니다.
from math import floor

import torch.nn as nn

class ConvNet(nn.Module):
    def __init__(self, hidden_layer_sizes=[32, 64],
                  kernel_sizes=[3],
                  activation="ReLU",
                  pool_sizes=[2],
                  dropout=0.5,
                  num_classes=num_classes,
                  input_shape=input_shape):
      
        super(ConvNet, self).__init__()

        self.layer1 = nn.Sequential(
              nn.Conv2d(in_channels=input_shape[0], out_channels=hidden_layer_sizes[0], kernel_size=kernel_sizes[0]),
              getattr(nn, activation)(),
              nn.MaxPool2d(kernel_size=pool_sizes[0])
        )
        self.layer2 = nn.Sequential(
              nn.Conv2d(in_channels=hidden_layer_sizes[0], out_channels=hidden_layer_sizes[-1], kernel_size=kernel_sizes[-1]),
              getattr(nn, activation)(),
              nn.MaxPool2d(kernel_size=pool_sizes[-1])
        )
        self.layer3 = nn.Sequential(
              nn.Flatten(),
              nn.Dropout(dropout)
        )

        fc_input_dims = floor((input_shape[1] - kernel_sizes[0] + 1) / pool_sizes[0]) # 레이어 1의 출력 크기
        fc_input_dims = floor((fc_input_dims - kernel_sizes[-1] + 1) / pool_sizes[-1]) # 레이어 2의 출력 크기
        fc_input_dims = fc_input_dims*fc_input_dims*hidden_layer_sizes[-1] # 레이어 3의 출력 크기

        self.fc = nn.Linear(fc_input_dims, num_classes)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.fc(x)
        return x
여기서는 실행을 추적하기 위해 W&B를 사용하므로, run.config 객체를 사용해 모든 하이퍼파라미터를 저장합니다. 해당 config 객체의 dict 형식 버전은 매우 유용한 메타데이터이므로, 반드시 포함하세요.
def build_model_and_log(config):
    with wandb.init(project="artifacts-example", job_type="initialize", config=config) as run:
        config = run.config
        
        model = ConvNet(**config)

        model_artifact = wandb.Artifact(
            "convnet", type="model",
            description="Simple AlexNet style CNN",
            metadata=dict(config))

        torch.save(model.state_dict(), "initialized_model.pth")
        # ➕ 파일을 아티팩트에 추가하는 또 다른 방법
        model_artifact.add_file("initialized_model.pth")

        run.save("initialized_model.pth")

        run.log_artifact(model_artifact)

model_config = {"hidden_layer_sizes": [32, 64],
                "kernel_sizes": [3],
                "activation": "ReLU",
                "pool_sizes": [2],
                "dropout": 0.5,
                "num_classes": 10}

build_model_and_log(model_config)

artifact.add_file()

데이터셋 로깅 예제에서처럼 new_file을 생성하면서 동시에 Artifact에 추가하는 대신, 파일을 한 단계에서 먼저 작성하고 (여기서는 torch.save), 그 다음 단계에서 해당 파일을 Artifactadd할 수도 있습니다.
모범 사례: 가능한 경우 중복을 방지하기 위해 new_file을 사용하세요.

기록된 모델 아티팩트 사용하기

datasetuse_artifact를 호출했던 것처럼, 다른 실행에서 사용하기 위해 initialized_model에도 use_artifact를 호출할 수 있습니다. 이번에는 modeltrain해 보겠습니다. 자세한 내용은 PyTorch로 W&B 계측하기 Colab을 참고하세요.
import wandb
import torch.nn.functional as F

def train(model, train_loader, valid_loader, config):
    optimizer = getattr(torch.optim, config.optimizer)(model.parameters())
    model.train()
    example_ct = 0
    for epoch in range(config.epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()

            example_ct += len(data)

            if batch_idx % config.batch_log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0%})]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    batch_idx / len(train_loader), loss.item()))
                
                train_log(loss, example_ct, epoch)

        # 각 에포크마다 검증 세트에서 모델을 평가합니다
        loss, accuracy = test(model, valid_loader)  
        test_log(loss, accuracy, example_ct, epoch)

    
def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum')  # 배치 손실 합산
            pred = output.argmax(dim=1, keepdim=True)  # 최대 로그 확률의 인덱스 가져오기
            correct += pred.eq(target.view_as(pred)).sum()

    test_loss /= len(test_loader.dataset)

    accuracy = 100. * correct / len(test_loader.dataset)
    
    return test_loss, accuracy


def train_log(loss, example_ct, epoch):
    loss = float(loss)

    # 핵심 동작이 이루어지는 부분
    with wandb.init(project="artifacts-example", job_type="train") as run:
        run.log({"epoch": epoch, "train/loss": loss}, step=example_ct)
        print(f"Loss after " + str(example_ct).zfill(5) + f" examples: {loss:.3f}")
    

def test_log(loss, accuracy, example_ct, epoch):
    loss = float(loss)
    accuracy = float(accuracy)

    # 핵심 동작이 이루어지는 부분
    with wandb.init() as run:
        run.log({"epoch": epoch, "validation/loss": loss, "validation/accuracy": accuracy}, step=example_ct)
        print(f"Loss/accuracy after " + str(example_ct).zfill(5) + f" examples: {loss:.3f}/{accuracy:.3f}")
이번에는 Artifact를 생성하는 서로 다른 두 개의 Run을 실행합니다. 첫 번째 Runmodeltrain하는 작업을 마치면, secondtrained-model Artifact를 가져와 test_dataset에서 성능을 evaluate합니다. 또한, 네트워크가 가장 혼란스러워하는 32개의 예제를 뽑아낼 것입니다 — 즉, categorical_crossentropy가 가장 높은 예제들입니다. 이는 데이터셋과 모델의 문제를 진단하는 데 좋은 방법입니다.
def evaluate(model, test_loader):
    """
    ## 학습된 모델 평가
    """

    loss, accuracy = test(model, test_loader)
    highest_losses, hardest_examples, true_labels, predictions = get_hardest_k_examples(model, test_loader.dataset)

    return loss, accuracy, highest_losses, hardest_examples, true_labels, predictions

def get_hardest_k_examples(model, testing_set, k=32):
    model.eval()

    loader = DataLoader(testing_set, 1, shuffle=False)

    # 데이터셋의 각 항목에 대한 손실 및 예측값 가져오기
    losses = None
    predictions = None
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = F.cross_entropy(output, target)
            pred = output.argmax(dim=1, keepdim=True)
            
            if losses is None:
                losses = loss.view((1, 1))
                predictions = pred
            else:
                losses = torch.cat((losses, loss.view((1, 1))), 0)
                predictions = torch.cat((predictions, pred), 0)

    argsort_loss = torch.argsort(losses, dim=0)

    highest_k_losses = losses[argsort_loss[-k:]]
    hardest_k_examples = testing_set[argsort_loss[-k:]][0]
    true_labels = testing_set[argsort_loss[-k:]][1]
    predicted_labels = predictions[argsort_loss[-k:]]

    return highest_k_losses, hardest_k_examples, true_labels, predicted_labels
이러한 로깅 함수들은 새로운 Artifact 기능을 추가하지 않으므로, 여기서는 별도로 설명하지 않겠습니다: 단지 Artifact에 대해 use, download, log만 할 뿐입니다.
from torch.utils.data import DataLoader

def train_and_log(config):

    with wandb.init(project="artifacts-example", job_type="train", config=config) as run:
        config = run.config

        data = run.use_artifact('mnist-preprocess:latest')
        data_dir = data.download()

        training_dataset =  read(data_dir, "training")
        validation_dataset = read(data_dir, "validation")

        train_loader = DataLoader(training_dataset, batch_size=config.batch_size)
        validation_loader = DataLoader(validation_dataset, batch_size=config.batch_size)
        
        model_artifact = run.use_artifact("convnet:latest")
        model_dir = model_artifact.download()
        model_path = os.path.join(model_dir, "initialized_model.pth")
        model_config = model_artifact.metadata
        config.update(model_config)

        model = ConvNet(**model_config)
        model.load_state_dict(torch.load(model_path))
        model = model.to(device)
 
        train(model, train_loader, validation_loader, config)

        model_artifact = wandb.Artifact(
            "trained-model", type="model",
            description="Trained NN model",
            metadata=dict(model_config))

        torch.save(model.state_dict(), "trained_model.pth")
        model_artifact.add_file("trained_model.pth")
        run.save("trained_model.pth")

        run.log_artifact(model_artifact)

    return model

    
def evaluate_and_log(config=None):
    
    with wandb.init(project="artifacts-example", job_type="report", config=config) as run:
        data = run.use_artifact('mnist-preprocess:latest')
        data_dir = data.download()
        testing_set = read(data_dir, "test")

        test_loader = torch.utils.data.DataLoader(testing_set, batch_size=128, shuffle=False)

        model_artifact = run.use_artifact("trained-model:latest")
        model_dir = model_artifact.download()
        model_path = os.path.join(model_dir, "trained_model.pth")
        model_config = model_artifact.metadata

        model = ConvNet(**model_config)
        model.load_state_dict(torch.load(model_path))
        model.to(device)

        loss, accuracy, highest_losses, hardest_examples, true_labels, preds = evaluate(model, test_loader)

        run.summary.update({"loss": loss, "accuracy": accuracy})

        run.log({"high-loss-examples":
            [wandb.Image(hard_example, caption=str(int(pred)) + "," +  str(int(label)))
             for hard_example, pred, label in zip(hardest_examples, preds, true_labels)]})
train_config = {"batch_size": 128,
                "epochs": 5,
                "batch_log_interval": 25,
                "optimizer": "Adam"}

model = train_and_log(train_config)
evaluate_and_log()