Computer Vision News - April 2023
# [Page 31 — MONAI Generative Models code listing, reconstructed from the
# magazine's single-line PDF extraction. Two extraction artifacts fixed:
# "images. shape[0]" -> "images.shape[0]" and the line-wrap hyphen in
# "diffusion_model=- model" -> "diffusion_model=model".]
#
# Mixed-precision training loop for a diffusion model using a MONAI
# DiffusionInferer: at each step we add random Gaussian noise to the images
# at a random timestep and train the model to predict that noise (MSE loss).
# Assumes `model`, `inferer`, `optimizer`, `train_loader`, `val_loader`,
# `device`, `n_epochs`, `val_interval`, and `tqdm` are defined earlier in
# the tutorial — TODO confirm against the full notebook.

epoch_loss_list = []      # mean training loss per epoch
val_epoch_loss_list = []  # mean validation loss per validation epoch
scaler = GradScaler()     # loss scaler for float16 autocast training
total_start = time.time()

for epoch in range(n_epochs):
    model.train()
    epoch_loss = 0
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=70)
    progress_bar.set_description(f"Epoch {epoch}")
    for step, batch in progress_bar:
        images = batch["image"].to(device)
        # set_to_none=True frees the grad tensors instead of zeroing in place
        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=True):
            # Generate random noise
            noise = torch.randn_like(images).to(device)

            # Create timesteps: one random diffusion step per sample
            timesteps = torch.randint(
                0, inferer.scheduler.num_train_timesteps, (images.shape[0],),
                device=images.device
            ).long()

            # Get model prediction of the added noise
            noise_pred = inferer(inputs=images, diffusion_model=model,
                                 noise=noise, timesteps=timesteps)

            loss = F.mse_loss(noise_pred.float(), noise.float())

        # Scaled backward + optimizer step for stable mixed-precision updates
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()
        progress_bar.set_postfix({"loss": epoch_loss / (step + 1)})
    epoch_loss_list.append(epoch_loss / (step + 1))

    if (epoch + 1) % val_interval == 0:
        model.eval()
        val_epoch_loss = 0
        for step, batch in enumerate(val_loader):
            images = batch["image"].to(device)
            noise = torch.randn_like(images).to(device)
            with torch.no_grad():
                with autocast(enabled=True):
                    timesteps = torch.randint(
                        0, inferer.scheduler.num_train_timesteps,
                        (images.shape[0],), device=images.device
                    ).long()

                    # Get model prediction
                    noise_pred = inferer(inputs=images, diffusion_model=model,
                                         noise=noise, timesteps=timesteps)
            # NOTE(review): the magazine listing is truncated here — the
            # validation loss computation/accumulation continues on a page
            # not shown in this extract.
Made with FlippingBook
RkJQdWJsaXNoZXIy NTc3NzU=