HiveBrain v1.2.0
pattern · csharp · Minor

Task based PowerShell cmdlet

Submitted by: @import:stackexchange-codereview
cmdlet · powershell · taskbased

Problem

I am learning about Tasks and also writing PowerShell cmdlets in C#. I found that connecting to remote machines was very slow, so I wrote this bit of code to speed it up using Tasks. I am hoping to get some feedback regarding the parallel processing and any other pointers in general.

The code is committed to a GitHub repo.

TaskCmdlet.cs

```
using System;
using System.Collections.Generic;
using System.Linq;
using System.Management.Automation;
using System.Threading;
using System.Threading.Tasks;

namespace PoshTasks.Cmdlets
{
    public abstract class TaskCmdlet<TIn, TOut> : Cmdlet
        where TIn : class
        where TOut : class
    {
        #region Parameters

        [Parameter(ValueFromPipeline = true)]
        public TIn[] InputObject { get; set; }

        #endregion

        #region Abstract methods

        /// <summary>
        /// Performs an action on <paramref name="input"/>
        /// </summary>
        /// <param name="input">The <typeparamref name="TIn"/> to be processed; null if not processing input</param>
        /// <returns>A <typeparamref name="TOut"/></returns>
        protected abstract TOut ProcessTask(TIn input = null);

        #endregion

        #region Virtual methods

        /// <summary>
        /// Generates a collection of tasks to be processed
        /// </summary>
        /// <returns>A collection of tasks</returns>
        protected virtual IEnumerable<Task<TOut>> GenerateTasks()
        {
            List<Task<TOut>> tasks = new List<Task<TOut>>();

            if (InputObject != null)
                foreach (TIn input in InputObject)
                    tasks.Add(Task.Run(() => ProcessTask(input)));
            else
                tasks.Add(Task.Run(() => ProcessTask()));

            return tasks;
        }

        /// <summary>
        /// Performs the pipeline output for this cmdlet
        /// </summary>
        /// <param name="result"></param>
        protected virtual void PostProcessTask(TOut result)
        {
            WriteObject(result, true);
        }

        #endregion

        #region Processing

        /// <summary>
        /// Processes cmdlet operation
        /// </summary>
        protected override void ProcessRecord()
        {
            IEnumerable<Task<TOut>> tasks = GenerateTasks();

            // ...
```

Solution

General conventions

The good old #region. Most people (including me) consider regions more harmful than helpful. You should avoid them.

```
if (InputObject != null)
    foreach (TIn input in InputObject)
        tasks.Add(Task.Run(() => ProcessTask(input)));
else
    tasks.Add(Task.Run(() => ProcessTask()));
```


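There's not even a single curly brace {} ;-) I wouldn't complain if this were Python, but in C# you should always use them. Omitting them can cause real headaches, for example when a second statement is later added under an if or a loop and silently ends up outside it.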
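The following sketch shows what kind of headache that is. Someone adds a second, indented statement to a loop that has no braces. The BracePitfall class and its methods are hypothetical and exist only for this illustration.

```csharp
using System.Collections.Generic;

public static class BracePitfall
{
    // Without braces, only the first statement after foreach is the loop body.
    public static (int Processed, int Logged) WithoutBraces(IEnumerable<string> inputs)
    {
        int processed = 0, logged = 0;
        foreach (var input in inputs)
            processed++;
            logged++; // NOT part of the loop: runs once, after it, despite the indentation
        return (processed, logged);
    }

    // With braces, both statements clearly run once per item.
    public static (int Processed, int Logged) WithBraces(IEnumerable<string> inputs)
    {
        int processed = 0, logged = 0;
        foreach (var input in inputs)
        {
            processed++;
            logged++;
        }
        return (processed, logged);
    }
}
```

For three inputs, WithoutBraces returns (3, 1) and WithBraces returns (3, 3). The C# compiler accepts the misleading indentation in the first method without complaint.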

```
protected virtual IEnumerable<Task<TOut>> GenerateTasks()
{
    List<Task<TOut>> tasks = new List<Task<TOut>>();

    if (InputObject != null)
        foreach (TIn input in InputObject)
            tasks.Add(Task.Run(() => ProcessTask(input)));
    else
        tasks.Add(Task.Run(() => ProcessTask()));

    return tasks;
}
```


In cases like this you can use yield return, which greatly simplifies the code:

```
protected virtual IEnumerable<Task<TOut>> CreateProcessTasks()
{
    if (InputObject == null)
    {
        yield return Task.Run(() => ProcessTask());
        yield break;
    }

    foreach (var input in InputObject)
    {
        yield return Task.Run(() => ProcessTask(input));
    }
}
```


I don't think this method needs to be virtual: generating tasks is not something you'd want to reimplement in each derived class. Also consider renaming it to CreateProcessTasks, since that is what it does; Generate sounds like it would create some random tasks.
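The iterator version is lazy, and that is worth knowing. No task starts until the sequence is enumerated, and every new enumeration starts the work again. This is why the sequence should be materialized once, for example with ToArray() before Task.WhenAll. The sketch below makes the laziness visible. LazyTaskDemo, Work and the Started counter are hypothetical stand-ins for CreateProcessTasks and ProcessTask.

```csharp
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

public static class LazyTaskDemo
{
    // Counts how many times the simulated work item actually started.
    public static int Started;

    // Stand-in for ProcessTask: records that it ran and returns its input doubled.
    static int Work(int input)
    {
        Interlocked.Increment(ref Started);
        return input * 2;
    }

    // Same shape as CreateProcessTasks: Task.Run is only called while the sequence is enumerated.
    public static IEnumerable<Task<int>> CreateTasks(int[] inputs)
    {
        foreach (var input in inputs)
        {
            yield return Task.Run(() => Work(input));
        }
    }
}
```

Calling CreateTasks starts nothing. ToArray() starts all three tasks. Passing the same un-materialized sequence to Task.WhenAll a second time starts three more.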

async/await

For async/await to do anything, you need to actually await something, but I couldn't find an await in your code. Let's fix that and introduce a few other changes that make your code look better.

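I'll start with the Interleaved method... and you don't actually need it. Unless you need to handle each result the moment its task finishes, which is what a completion-order helper like that is usually for, everything it does can be reduced to a single line: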

```
var results = await Task.WhenAll(tasks.ToArray());
```


Where do I put it? I moved it into a new async ProcessRecordCore method. After this adjustment, the processing code looks like this:

```
protected override void ProcessRecord()
{
    var errorRecords = Task.Run(async () => await ProcessRecordCore()).Result;

    foreach (var errorRecord in errorRecords)
    {
        WriteError(errorRecord);
    }
}

private async Task<BlockingCollection<ErrorRecord>> ProcessRecordCore()
{
    var tasks = CreateProcessTasks();

    var results = await Task.WhenAll(tasks.ToArray());

    var errorRecords = new BlockingCollection<ErrorRecord>();

    foreach (var result in results)
    {
        try
        {
            PostProcessTask(result);
        }
        catch (Exception e) when (e is PipelineStoppedException || e is PipelineClosedException)
        {
            // do nothing if pipeline stops
        }
        catch (Exception e)
        {
            errorRecords.Add(new ErrorRecord(e, e.GetType().Name, ErrorCategory.NotSpecified, this));
        }
    }

    return errorRecords;
}
```


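Notice that ProcessRecordCore is now marked as async, so it can await the tasks. ProcessRecord is called synchronously by PowerShell, so it blocks until ProcessRecordCore completes by reading .Result.

One caveat: Cmdlet.WriteObject and WriteError may only be called on the thread that invoked ProcessRecord. Everything inside Task.Run executes on a thread-pool thread, so PostProcessTask would throw a PSInvalidOperationException there, and the catch block above would then report it as an error record. In practice, await only the tasks in the async part and keep the PostProcessTask loop on the calling thread in ProcessRecord.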
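Dropping Interleaved involves a trade-off. Task.WhenAll waits for every task and returns the results in the order the tasks were passed in. A Task.WhenAny loop instead hands you each result as soon as its task finishes. The sketch below shows both, using simulated remote calls; FakeRemoteCall and its delays are made up for the illustration.

```csharp
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

public static class OrderingDemo
{
    // Simulates a remote call whose latency is given in milliseconds.
    static async Task<string> FakeRemoteCall(string name, int delayMs)
    {
        await Task.Delay(delayMs);
        return name;
    }

    // Starts a slow call first and a fast call second.
    static Task<string>[] StartCalls() => new[]
    {
        FakeRemoteCall("slow", 500),
        FakeRemoteCall("fast", 50),
    };

    // Task.WhenAll: waits for all tasks; results keep the order the tasks were passed in.
    public static async Task<List<string>> InInputOrder()
    {
        var results = await Task.WhenAll(StartCalls());
        return results.ToList();
    }

    // Task.WhenAny loop: handles each result as soon as its task completes.
    public static async Task<List<string>> InCompletionOrder()
    {
        var pending = StartCalls().ToList();
        var results = new List<string>();
        while (pending.Count > 0)
        {
            var finished = await Task.WhenAny(pending);
            pending.Remove(finished);
            results.Add(await finished);
        }
        return results;
    }
}
```

InInputOrder yields "slow", "fast", while InCompletionOrder yields "fast", "slow". The WhenAny loop does O(n²) work in the number of tasks. That is irrelevant for a handful of machines, but it is the reason completion-order helpers exist for larger inputs.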

There are two more methods that can be simplified. The first is the ProcessTask override in your services cmdlet, where the ?: conditional (ternary) operator lets you drop the if:

```
protected override ServiceController[] ProcessTask(string server)
{
    var services = ServiceController.GetServices(server);

    return
        Name == null
        ? services
        : services.Where(s => Name.Contains(s.DisplayName)).ToArray();
}
```


Or you can go crazy and make it a one-liner:

```
return services.Where(s => Name == null || Name.Contains(s.DisplayName)).ToArray();
```
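Both forms return the same items. The only difference is that the ternary version hands back the original array untouched when no Name filter is given, while the one-liner always builds a new array. The quick check below uses plain display-name strings instead of ServiceController objects and assumes Name is a string[] parameter. NameFilter is hypothetical.

```csharp
using System.Linq;

public static class NameFilter
{
    // Ternary form: no filtering at all when no names were requested.
    public static string[] FilterTernary(string[] displayNames, string[] name) =>
        name == null
            ? displayNames
            : displayNames.Where(d => name.Contains(d)).ToArray();

    // One-liner form: the null check is folded into the predicate.
    public static string[] FilterOneLiner(string[] displayNames, string[] name) =>
        displayNames.Where(d => name == null || name.Contains(d)).ToArray();
}
```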


The other one is the PostProcessTask method that could use some vars (like the rest of the code):

```
protected override void PostProcessTask(ServiceController[] result)
{
    // anonymous objects can only be collected as object here
    var services = new List<object>();

    foreach (var service in result)
    {
        services.Add(new
        {
            Name = service.DisplayName,
            Status = service.Status,
            ComputerName = service.MachineName,
            CanPause = service.CanPauseAndContinue
        });
    }

    WriteObject(services, true);
}
```
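To go one step further, the loop can become a LINQ projection, which removes the explicit list handling. The sketch below uses a hypothetical FakeService class instead of ServiceController so it runs without System.ServiceProcess. Status is simplified to a string here.

```csharp
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-in for ServiceController (only the properties used above).
public sealed class FakeService
{
    public string DisplayName { get; set; }
    public string Status { get; set; }
    public string MachineName { get; set; }
    public bool CanPauseAndContinue { get; set; }
}

public static class ServiceProjection
{
    // Same anonymous shape as the answer's loop, built with Select instead of List.Add.
    public static List<object> Project(IEnumerable<FakeService> result) =>
        result.Select(service => (object)new
        {
            Name = service.DisplayName,
            Status = service.Status,
            ComputerName = service.MachineName,
            CanPause = service.CanPauseAndContinue
        }).ToList();
}
```

Inside the cmdlet, the Select result could be passed straight to WriteObject(..., true). With enumerateCollection set to true, PowerShell emits each object separately.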


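IProgress<T> interface

To see the errors as soon as they are reported, you may try another approach based on IProgress<T>. Keep in mind that Progress<T> invokes its callback through the SynchronizationContext captured when it is constructed. When there is none, which is typically the case on the PowerShell pipeline thread, the callback runs on the thread pool. WriteError may therefore end up being called off the pipeline thread, which PowerShell does not allow. Here's an example: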

```
protected override void ProcessRecord()
{
    var progress = new Progress<ErrorRecord>(errorRecord =>
    {
        WriteError(errorRecord);
    });

    // Block until the work finishes; ProcessRecord itself cannot be async.
    Task.Run(async () => await ProcessRecordCore(progress)).Wait();
}

private async Task ProcessRecordCore(IProgress<ErrorRecord> progress)
{
    var tasks = CreateProcessTasks();

    var results = await Task.WhenAll(tasks.ToArray());

    foreach (var result in results)
    {
        try
        {
            PostProcessTask(result);
        }
        catch (Exception e) when (e is PipelineStoppedException || e is PipelineClosedException)
        {
            // do nothing if pipeline stops
        }
        catch (Exception e)
        {
            progress.Report(new ErrorRecord(e, e.GetType().Name, ErrorCategory.NotSpecified, this));
        }
    }
}
```
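There is another way to show results as they arrive. PowerShell requires WriteObject and WriteError to be called only on the thread running ProcessRecord. The worker tasks can fill a thread-safe queue that is drained on that thread, which respects this rule. Below is a minimal sketch with the cmdlet left out so it runs anywhere. ProcessOnCallingThread is a hypothetical helper, and its write callback stands in for WriteObject.

```csharp
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

public static class PipelineThreadDemo
{
    // Runs 'work' for every input in parallel, but invokes 'write' only on the calling thread.
    public static void ProcessOnCallingThread<TIn, TOut>(
        IEnumerable<TIn> inputs, Func<TIn, TOut> work, Action<TOut> write)
    {
        using (var queue = new BlockingCollection<TOut>())
        {
            // Producers: each task computes a result and queues it as soon as it is ready.
            Task[] tasks = inputs
                .Select(input => Task.Run(() => queue.Add(work(input))))
                .ToArray();

            // Once every producer has finished (successfully or not), let the consumer stop.
            Task completion = Task.WhenAll(tasks)
                .ContinueWith(_ => queue.CompleteAdding(), TaskScheduler.Default);

            // Consumer: runs on the calling thread and sees results in completion order.
            foreach (TOut item in queue.GetConsumingEnumerable())
            {
                write(item);
            }

            completion.Wait();
            // Rethrows (wrapped in an AggregateException) if any work item failed.
            Task.WaitAll(tasks);
        }
    }
}
```

This sketch does not handle cancellation, such as a stopped pipeline. A real cmdlet would probably also pass a CancellationToken to the work items.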



One important aspect of this


Context

StackExchange Code Review Q#146810, answer score: 3
