Task-based PowerShell cmdlet
Problem
I am learning about Tasks and also writing PowerShell cmdlets in C#. I found that connecting to remote machines was very slow, so I wrote this bit of code to speed it up using Tasks. I am hoping to get some feedback regarding the parallel processing and any other pointers in general.
The code is committed to a GitHub repo.
TaskCmdlet.cs
```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Management.Automation;
using System.Threading;
using System.Threading.Tasks;

namespace PoshTasks.Cmdlets
{
    public abstract class TaskCmdlet<TIn, TOut> : Cmdlet
        where TIn : class
        where TOut : class
    {
        #region Parameters

        [Parameter(ValueFromPipeline = true)]
        public TIn[] InputObject { get; set; }

        #endregion

        #region Abstract methods

        /// <summary>
        /// Performs an action on <paramref name="input"/>
        /// </summary>
        /// <param name="input">The <typeparamref name="TIn"/> to be processed; null if not processing input</param>
        /// <returns>A <typeparamref name="TOut"/></returns>
        protected abstract TOut ProcessTask(TIn input = null);

        #endregion

        #region Virtual methods

        /// <summary>
        /// Generates a collection of tasks to be processed
        /// </summary>
        /// <returns>A collection of tasks</returns>
        protected virtual IEnumerable<Task<TOut>> GenerateTasks()
        {
            List<Task<TOut>> tasks = new List<Task<TOut>>();

            if (InputObject != null)
                foreach (TIn input in InputObject)
                    tasks.Add(Task.Run(() => ProcessTask(input)));
            else
                tasks.Add(Task.Run(() => ProcessTask()));

            return tasks;
        }

        /// <summary>
        /// Performs the pipeline output for this cmdlet
        /// </summary>
        /// <param name="result"></param>
        protected virtual void PostProcessTask(TOut result)
        {
            WriteObject(result, true);
        }

        #endregion

        #region Processing

        /// <summary>
        /// Processes cmdlet operation
        /// </summary>
        protected override void ProcessRecord()
        {
            IEnumerable<Task<TOut>> tasks = GenerateTasks();
            // ... (the rest of the listing is cut off in this copy of the question)
```
Solution
General conventions
The good old `#region`. Most people (including me) consider regions more harmful than helpful, so you should avoid them.

```csharp
if (InputObject != null)
    foreach (TIn input in InputObject)
        tasks.Add(Task.Run(() => ProcessTask(input)));
else
    tasks.Add(Task.Run(() => ProcessTask()));
```

There's not even a single curly brace `{}` ;-) I wouldn't complain if it were Python, but in C# you should always use them. Omitting them can cause a real headache.

```csharp
protected virtual IEnumerable<Task<TOut>> GenerateTasks()
{
    List<Task<TOut>> tasks = new List<Task<TOut>>();
    if (InputObject != null)
        foreach (TIn input in InputObject)
            tasks.Add(Task.Run(() => ProcessTask(input)));
    else
        tasks.Add(Task.Run(() => ProcessTask()));
    return tasks;
}
```

In cases like this you can use `yield return`, which greatly simplifies the code:

```csharp
protected virtual IEnumerable<Task<TOut>> CreateProcessTasks()
{
    if (InputObject == null)
    {
        yield return Task.Run(() => ProcessTask());
        yield break;
    }

    foreach (var input in InputObject)
    {
        yield return Task.Run(() => ProcessTask(input));
    }
}
```

I think this method doesn't need to be `virtual`: generating tasks is not something you'd want to reimplement in each derived class. Consider changing its name to `CreateProcessTasks`, since that is what it does; `Generate` sounds like it would create some random tasks.

async/await
In order for `async`/`await` to work, you need to actually `await` something, but I couldn't find any `await` in your code. Let's fix that and introduce a few other changes that make your code look better.

I'll start with the `Interleaved` method... and you actually don't need it. Everything it does can be reduced to a single line:

```csharp
var results = await Task.WhenAll(tasks.ToArray());
```

Where do I put it? I move it into the `ProcessRecordCore` method. After this adjustment, `ProcessRecord` and `ProcessRecordCore` look like this:

```csharp
protected override void ProcessRecord()
{
    var errorRecords = Task.Run(async () => await ProcessRecordCore()).Result;
    foreach (var errorRecord in errorRecords)
    {
        WriteError(errorRecord);
    }
}

private async Task<BlockingCollection<ErrorRecord>> ProcessRecordCore()
{
    var tasks = CreateProcessTasks();
    var results = await Task.WhenAll(tasks.ToArray());
    var errorRecords = new BlockingCollection<ErrorRecord>();
    foreach (var result in results)
    {
        try
        {
            PostProcessTask(result);
        }
        catch (Exception e) when (e is PipelineStoppedException || e is PipelineClosedException)
        {
            // do nothing if pipeline stops
        }
        catch (Exception e)
        {
            errorRecords.Add(new ErrorRecord(e, e.GetType().Name, ErrorCategory.NotSpecified, this));
        }
    }
    return errorRecords;
}
```

Notice that `ProcessRecordCore` is now marked as `async`, so you can `await` it (as the lambda passed to `Task.Run` does), and `ProcessRecord` blocks until it completes by reading `.Result`.

There are two more methods that can be simplified. The first one is the `ProcessTask` method, where you can use the `?:` ternary operator and don't need the `if`:

```csharp
protected override ServiceController[] ProcessTask(string server)
{
    var services = ServiceController.GetServices(server);
    return
        Name == null
        ? services
        : services.Where(s => Name.Contains(s.DisplayName)).ToArray();
}
```

Or you can go crazy and make it a one-liner:

```csharp
return services.Where(s => Name == null || Name.Contains(s.DisplayName)).ToArray();
```

The other one is the `PostProcessTask` method, which could use some `var`s (like the rest of the code):

```csharp
protected override void PostProcessTask(ServiceController[] result)
{
    var services = new List<object>();
    foreach (var service in result)
    {
        services.Add(new
        {
            Name = service.DisplayName,
            Status = service.Status,
            ComputerName = service.MachineName,
            CanPause = service.CanPauseAndContinue
        });
    }
    WriteObject(services, true);
}
```

IProgress interface
To see the errors right away, you may try another approach with the `IProgress<T>` interface. Here's an example:

```csharp
protected override void ProcessRecord()
{
    var progress = new Progress<ErrorRecord>(errorRecord =>
    {
        WriteError(errorRecord);
    });

    // Block until every task has been processed before ProcessRecord returns.
    Task.Run(async () => await ProcessRecordCore(progress)).Wait();
}

private async Task ProcessRecordCore(IProgress<ErrorRecord> progress)
{
    var tasks = CreateProcessTasks();
    var results = await Task.WhenAll(tasks.ToArray());
    foreach (var result in results)
    {
        try
        {
            PostProcessTask(result);
        }
        catch (Exception e) when (e is PipelineStoppedException || e is PipelineClosedException)
        {
            // do nothing if pipeline stops
        }
        catch (Exception e)
        {
            progress.Report(new ErrorRecord(e, e.GetType().Name, ErrorCategory.NotSpecified, this));
        }
    }
}
```

One important aspect of this...
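To make the "real headache" from the braces remark concrete, here is a small console sketch (top-level C# statements, with hypothetical string lists rather than the cmdlet's tasks). A second line added under an unbraced `foreach` looks like part of the loop but runs only once, after the loop has finished.

```csharp
using System;
using System.Collections.Generic;

var inputs = new[] { "a", "b", "c" };
var withoutBraces = new List<string>();
var withBraces = new List<string>();

// Without braces only the first statement belongs to the loop;
// the indentation of the second line is misleading.
foreach (var input in inputs)
    withoutBraces.Add(input);
    withoutBraces.Add("log");   // runs once, after the loop

// With braces both statements run on every iteration.
foreach (var input in inputs)
{
    withBraces.Add(input);
    withBraces.Add("log");
}

Console.WriteLine(withoutBraces.Count); // 4
Console.WriteLine(withBraces.Count);    // 6
```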
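One consequence of the `yield return` version of `CreateProcessTasks` is worth keeping in mind. The iterator is lazy, so no task starts until the sequence is enumerated, and every enumeration starts a fresh set of tasks. The sketch below uses a hypothetical local `CreateProcessTasks(int[])` that doubles integers in place of the real remote calls, and counts how many tasks were started. Calling `.ToArray()` once, as `ProcessRecordCore` does, is what keeps the work from being repeated.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

int started = 0;

// Hypothetical stand-in for CreateProcessTasks: an iterator that starts
// one task per input; each task counts itself and returns input * 2.
IEnumerable<Task<int>> CreateProcessTasks(int[] inputs)
{
    foreach (var input in inputs)
    {
        yield return Task.Run(() =>
        {
            Interlocked.Increment(ref started);
            return input * 2;
        });
    }
}

var lazy = CreateProcessTasks(new[] { 1, 2, 3 });
int beforeEnumeration = started;   // the iterator body has not run yet

int[] first = Task.WhenAll(lazy).GetAwaiter().GetResult();   // enumerates: starts 3 tasks
int[] second = Task.WhenAll(lazy).GetAwaiter().GetResult();  // enumerates again: 3 new tasks
int afterLazy = started;

// Materializing once fixes the set of tasks, so waiting twice does no extra work.
Task<int>[] snapshot = CreateProcessTasks(new[] { 1, 2, 3 }).ToArray();
Task.WhenAll(snapshot).GetAwaiter().GetResult();
Task.WhenAll(snapshot).GetAwaiter().GetResult();
int afterSnapshot = started;

Console.WriteLine($"{beforeEnumeration} {afterLazy} {afterSnapshot}"); // 0 6 9
Console.WriteLine(string.Join(",", first));                           // 2,4,6
```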
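When the results only need to be collected, `Task.WhenAll` is enough on its own. It waits for every task and returns the results in the order of the input tasks, not in the order they finished, so `results[i]` always belongs to `tasks[i]`. A minimal sketch with hypothetical delays chosen so that the last task finishes first:

```csharp
using System;
using System.Linq;
using System.Threading.Tasks;

// Hypothetical work items: the later the item, the shorter its delay,
// so the tasks complete in roughly reverse order.
int[] delays = { 300, 200, 100 };

Task<int>[] tasks = delays
    .Select(async (delay, index) =>
    {
        await Task.Delay(delay);
        return index;
    })
    .ToArray();

// WhenAll waits for all tasks; results[i] is the result of tasks[i].
int[] results = await Task.WhenAll(tasks);
Console.WriteLine(string.Join(",", results)); // 0,1,2
```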
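A caveat about the `ProcessRecordCore` version: PowerShell only accepts `WriteObject` and `WriteError` calls made on the thread that is running `ProcessRecord`, and throws `PSInvalidOperationException` otherwise. Because `ProcessRecordCore` runs inside `Task.Run`, its `PostProcessTask` call (and therefore `WriteObject`) happens on a thread-pool thread, so each result would likely end up in the generic `catch` as an error record instead of being written. One way around this is to run only the slow work in parallel, block on `Task.WhenAll` from the calling thread, and write the results from there. The sketch below shows that shape without PowerShell; the `WriteObject` and `ProcessTask` local functions are hypothetical stand-ins that mimic the thread check and the remote call.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical stand-in for Cmdlet.WriteObject: like PowerShell, it rejects
// calls made from any thread other than the one that owns the pipeline.
int pipelineThread = Environment.CurrentManagedThreadId;
var written = new List<string>();
void WriteObject(string item)
{
    if (Environment.CurrentManagedThreadId != pipelineThread)
    {
        throw new InvalidOperationException("WriteObject called from another thread");
    }
    written.Add(item);
}

// Hypothetical stand-in for ProcessTask: the slow remote call.
string ProcessTask(string server) => $"services on {server}";

var servers = new[] { "srv1", "srv2", "srv3" };

// 1. Run the slow work in parallel on the thread pool and block until all of it is done.
string[] results = Task.WhenAll(servers.Select(s => Task.Run(() => ProcessTask(s))))
    .GetAwaiter()
    .GetResult();

// 2. Write the output from the calling ("pipeline") thread only.
foreach (var result in results)
{
    WriteObject(result);
}

// For contrast: the same call made from another thread is rejected.
bool failedOffThread = false;
var worker = new Thread(() =>
{
    try
    {
        WriteObject("late");
    }
    catch (InvalidOperationException)
    {
        failedOffThread = true;
    }
});
worker.Start();
worker.Join();

Console.WriteLine($"{written.Count} written, off-thread write rejected: {failedOffThread}"); // 3 written, off-thread write rejected: True
```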
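The ternary and the one-liner versions of `ProcessTask` return the same services. The sketch below checks that with plain strings instead of `ServiceController` (which needs Windows), assuming `Name` is a `string[]` parameter so that `Contains` is LINQ's `Enumerable.Contains`. The practical difference is small: the ternary returns the original array untouched when no filter is given, while the one-liner always builds a new array.

```csharp
using System;
using System.Linq;

// Hypothetical display names standing in for ServiceController.GetServices(server).
string[] services = { "Windows Update", "Print Spooler", "DHCP Client" };

// Ternary form: return everything when no filter is given.
string[] FilterTernary(string[] all, string[] name) =>
    name == null ? all : all.Where(s => name.Contains(s)).ToArray();

// One-liner form: the null check moves into the predicate.
string[] FilterOneLiner(string[] all, string[] name) =>
    all.Where(s => name == null || name.Contains(s)).ToArray();

var filter = new[] { "Print Spooler" };
Console.WriteLine(string.Join(", ", FilterTernary(services, filter)));  // Print Spooler
Console.WriteLine(string.Join(", ", FilterOneLiner(services, filter))); // Print Spooler
Console.WriteLine(FilterOneLiner(services, null).Length);               // 3
```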
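The same thread restriction applies to the `IProgress<T>` version. `Progress<T>` posts its callback to the `SynchronizationContext` captured when it was created. With none installed (the usual situation on a PowerShell pipeline thread, though worth verifying in your host), the callback runs on the thread pool, so `WriteError` would be called off the pipeline thread. The sketch below first shows that behavior in a console app, then swaps in a different technique for streaming errors as they occur: worker tasks add to a `BlockingCollection<T>`, and the calling thread consumes it with `GetConsumingEnumerable()`, which is where a cmdlet could call `WriteError`. The work items and messages are hypothetical.

```csharp
using System;
using System.Collections.Concurrent;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

int callerThread = Environment.CurrentManagedThreadId;

// 1. Progress<T> posts its handler to the captured SynchronizationContext.
//    This console app has none, so the handler runs on a thread-pool thread.
bool handlerOnCaller = true;
using var reported = new ManualResetEventSlim();
IProgress<string> progress = new Progress<string>(message =>
{
    handlerOnCaller = Environment.CurrentManagedThreadId == callerThread;
    reported.Set();
});
progress.Report("hypothetical error");
reported.Wait();
Console.WriteLine($"Progress handler ran on caller thread: {handlerOnCaller}"); // Progress handler ran on caller thread: False

// 2. Alternative: producers add to a BlockingCollection, the caller consumes it.
using var errors = new BlockingCollection<string>();
var workers = Enumerable.Range(1, 4).Select(i => Task.Run(() =>
{
    // Hypothetical work item: even inputs "fail" and report an error.
    if (i % 2 == 0)
    {
        errors.Add($"item {i} failed");
    }
})).ToArray();

// Close the collection once every producer has finished (faulted or not).
Task.WhenAll(workers).ContinueWith(_ => errors.CompleteAdding());

// This loop runs on the calling thread and ends after CompleteAdding;
// a cmdlet could call WriteError here as each error arrives.
int consumed = 0;
bool allOnCaller = true;
foreach (var error in errors.GetConsumingEnumerable())
{
    allOnCaller &= Environment.CurrentManagedThreadId == callerThread;
    consumed++;
    Console.WriteLine(error);
}
```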
Context
StackExchange Code Review Q#146810, answer score: 3